1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: jit"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport torch 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Worker# This is how we include tests located in test/jit/... 6*da0073e9SAndroid Build Coastguard Worker# They are included here so that they are invoked when you call `test_jit.py`, 7*da0073e9SAndroid Build Coastguard Worker# do not run these test files directly. 8*da0073e9SAndroid Build Coastguard Workerfrom jit.test_tracer import TestTracer, TestMixTracingScripting # noqa: F401 9*da0073e9SAndroid Build Coastguard Workerfrom jit.test_recursive_script import TestRecursiveScript # noqa: F401 10*da0073e9SAndroid Build Coastguard Workerfrom jit.test_type_sharing import TestTypeSharing # noqa: F401 11*da0073e9SAndroid Build Coastguard Workerfrom jit.test_logging import TestLogging # noqa: F401 12*da0073e9SAndroid Build Coastguard Workerfrom jit.test_backends import TestBackends, TestBackendsWithCompiler # noqa: F401 13*da0073e9SAndroid Build Coastguard Workerfrom jit.test_backend_nnapi import TestNnapiBackend # noqa: F401 14*da0073e9SAndroid Build Coastguard Workerfrom jit.test_list_dict import TestList, TestDict, TestNamedTuple, TestScriptDict, TestScriptList # noqa: F401 15*da0073e9SAndroid Build Coastguard Workerfrom jit.test_async import TestAsync # noqa: F401 16*da0073e9SAndroid Build Coastguard Workerfrom jit.test_await import TestAwait # noqa: F401 17*da0073e9SAndroid Build Coastguard Workerfrom jit.test_data_parallel import TestDataParallel # noqa: F401 18*da0073e9SAndroid Build Coastguard Workerfrom jit.test_models import TestModels # noqa: F401 19*da0073e9SAndroid Build Coastguard Workerfrom jit.test_modules import TestModules # noqa: F401 20*da0073e9SAndroid Build Coastguard Workerfrom jit.test_autodiff import TestAutodiffJit # noqa: F401 21*da0073e9SAndroid Build Coastguard Workerfrom 
jit.test_autodiff_subgraph_slicing import TestAutodiffSubgraphSlicing # noqa: F401 22*da0073e9SAndroid Build Coastguard Workerfrom jit.test_custom_operators import TestCustomOperators # noqa: F401 23*da0073e9SAndroid Build Coastguard Workerfrom jit.test_graph_rewrite_passes import TestGraphRewritePasses # noqa: F401 24*da0073e9SAndroid Build Coastguard Workerfrom jit.test_class_type import TestClassType # noqa: F401 25*da0073e9SAndroid Build Coastguard Workerfrom jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401 26*da0073e9SAndroid Build Coastguard Workerfrom jit.test_ignore_context_manager import TestIgnoreContextManager # noqa: F401 27*da0073e9SAndroid Build Coastguard Workerfrom jit.test_symbolic_shape_analysis import TestSymbolicShapeAnalysis # noqa: F401 28*da0073e9SAndroid Build Coastguard Workerfrom jit.test_op_decompositions import TestOpDecompositions # noqa: F401 29*da0073e9SAndroid Build Coastguard Workerfrom jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401 30*da0073e9SAndroid Build Coastguard Workerfrom jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401 31*da0073e9SAndroid Build Coastguard Workerfrom jit.test_peephole import TestPeephole # noqa: F401 32*da0073e9SAndroid Build Coastguard Workerfrom jit.test_alias_analysis import TestAliasAnalysis # noqa: F401 33*da0073e9SAndroid Build Coastguard Workerfrom jit.test_save_load import TestSaveLoad, TestSaveLoadFlatbuffer # noqa: F401 34*da0073e9SAndroid Build Coastguard Workerfrom jit.test_save_load_for_op_version import TestSaveLoadForOpVersion # noqa: F401 35*da0073e9SAndroid Build Coastguard Workerfrom jit.test_module_containers import TestModuleContainers # noqa: F401 36*da0073e9SAndroid Build Coastguard Workerfrom jit.test_python_bindings import TestPythonBindings # noqa: F401 37*da0073e9SAndroid Build Coastguard Workerfrom jit.test_python_ir import TestPythonIr # noqa: F401 38*da0073e9SAndroid Build Coastguard 
Workerfrom jit.test_functional_blocks import TestFunctionalBlocks # noqa: F401 39*da0073e9SAndroid Build Coastguard Workerfrom jit.test_remove_mutation import TestRemoveMutation # noqa: F401 40*da0073e9SAndroid Build Coastguard Workerfrom jit.test_torchbind import TestTorchbind # noqa: F401 41*da0073e9SAndroid Build Coastguard Workerfrom jit.test_module_interface import TestModuleInterface # noqa: F401 42*da0073e9SAndroid Build Coastguard Workerfrom jit.test_with import TestWith # noqa: F401 43*da0073e9SAndroid Build Coastguard Workerfrom jit.test_enum import TestEnum # noqa: F401 44*da0073e9SAndroid Build Coastguard Workerfrom jit.test_string_formatting import TestStringFormatting # noqa: F401 45*da0073e9SAndroid Build Coastguard Workerfrom jit.test_profiler import TestProfiler # noqa: F401 46*da0073e9SAndroid Build Coastguard Workerfrom jit.test_slice import TestSlice # noqa: F401 47*da0073e9SAndroid Build Coastguard Workerfrom jit.test_ignorable_args import TestIgnorableArgs # noqa: F401 48*da0073e9SAndroid Build Coastguard Workerfrom jit.test_hooks import TestHooks # noqa: F401 49*da0073e9SAndroid Build Coastguard Workerfrom jit.test_warn import TestWarn # noqa: F401 50*da0073e9SAndroid Build Coastguard Workerfrom jit.test_isinstance import TestIsinstance # noqa: F401 51*da0073e9SAndroid Build Coastguard Workerfrom jit.test_cuda import TestCUDA # noqa: F401 52*da0073e9SAndroid Build Coastguard Workerfrom jit.test_python_builtins import TestPythonBuiltinOP # noqa: F401 53*da0073e9SAndroid Build Coastguard Workerfrom jit.test_typing import TestTyping # noqa: F401 54*da0073e9SAndroid Build Coastguard Workerfrom jit.test_hash import TestHash # noqa: F401 55*da0073e9SAndroid Build Coastguard Workerfrom jit.test_complex import TestComplex # noqa: F401 56*da0073e9SAndroid Build Coastguard Workerfrom jit.test_jit_utils import TestJitUtils # noqa: F401 57*da0073e9SAndroid Build Coastguard Workerfrom jit.test_scriptmod_ann import 
TestScriptModuleInstanceAttributeTypeAnnotation # noqa: F401 58*da0073e9SAndroid Build Coastguard Workerfrom jit.test_types import TestTypesAndAnnotation # noqa: F401 59*da0073e9SAndroid Build Coastguard Workerfrom jit.test_misc import TestMisc # noqa: F401 60*da0073e9SAndroid Build Coastguard Workerfrom jit.test_upgraders import TestUpgraders # noqa: F401 61*da0073e9SAndroid Build Coastguard Workerfrom jit.test_pdt import TestPDT # noqa: F401 62*da0073e9SAndroid Build Coastguard Workerfrom jit.test_tensor_creation_ops import TestTensorCreationOps # noqa: F401 63*da0073e9SAndroid Build Coastguard Workerfrom jit.test_module_apis import TestModuleAPIs # noqa: F401 64*da0073e9SAndroid Build Coastguard Workerfrom jit.test_script_profile import TestScriptProfile # noqa: F401 65*da0073e9SAndroid Build Coastguard Workerfrom jit.test_convert_activation import TestFunctionalToInplaceActivation, TestInplaceToFunctionalActivation # noqa: F401 66*da0073e9SAndroid Build Coastguard Workerfrom jit.test_parametrization import TestParametrization # noqa: F401 67*da0073e9SAndroid Build Coastguard Workerfrom jit.test_attr import TestGetDefaultAttr # noqa: F401 68*da0073e9SAndroid Build Coastguard Workerfrom jit.test_aten_pow import TestAtenPow # noqa: F401 69*da0073e9SAndroid Build Coastguard Workerfrom jit.test_optimize_for_mobile_preserve_debug_info import TestOptimizeForMobilePreserveDebugInfo # noqa: F401 70*da0073e9SAndroid Build Coastguard Workerfrom jit.test_union import TestUnion # noqa: F401 71*da0073e9SAndroid Build Coastguard Workerfrom jit.test_batch_mm import TestBatchMM # noqa: F401 72*da0073e9SAndroid Build Coastguard Workerfrom jit.test_dtype_analysis import TestDtypeAnalysis, TestDtypeCustomRulesCPU # noqa: F401 73*da0073e9SAndroid Build Coastguard Workerfrom jit.test_device_analysis import TestDeviceAnalysis # noqa: F401 74*da0073e9SAndroid Build Coastguard Workerfrom jit.test_dce import TestDCE # noqa: F401 75*da0073e9SAndroid Build Coastguard Workerfrom 
jit.test_sparse import TestSparse # noqa: F401 76*da0073e9SAndroid Build Coastguard Workerfrom jit.test_tensor_methods import TestTensorMethods # noqa: F401 77*da0073e9SAndroid Build Coastguard Workerfrom jit.test_dataclasses import TestDataclasses # noqa: F401 78*da0073e9SAndroid Build Coastguard Workerfrom jit.test_generator import TestGenerator # noqa: F401 79*da0073e9SAndroid Build Coastguard Worker 80*da0073e9SAndroid Build Coastguard Worker# Torch 81*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor 82*da0073e9SAndroid Build Coastguard Workerfrom torch._C import TensorType, BoolType, parse_ir, _propagate_shapes 83*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd import Variable 84*da0073e9SAndroid Build Coastguard Workerfrom torch.jit.annotations import BroadcastingList2, BroadcastingList3, Any # noqa: F401 85*da0073e9SAndroid Build Coastguard Workerfrom torch.nn.utils.rnn import PackedSequence 86*da0073e9SAndroid Build Coastguard Workerfrom torch.testing import FileCheck, make_tensor 87*da0073e9SAndroid Build Coastguard Workerimport torch.autograd.profiler 88*da0073e9SAndroid Build Coastguard Workerimport torch.cuda 89*da0073e9SAndroid Build Coastguard Workerimport torch.jit 90*da0073e9SAndroid Build Coastguard Workerimport torch.jit._logging 91*da0073e9SAndroid Build Coastguard Workerimport torch.jit.frontend 92*da0073e9SAndroid Build Coastguard Workerimport torch.nn as nn 93*da0073e9SAndroid Build Coastguard Workerimport torch.nn.functional as F 94*da0073e9SAndroid Build Coastguard Worker 95*da0073e9SAndroid Build Coastguard Worker# Testing utils 96*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal import jit_utils 97*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_jit import check_against_reference 98*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ 99*da0073e9SAndroid Build Coastguard Worker 
suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \ 100*da0073e9SAndroid Build Coastguard Worker freeze_rng_state, slowTest, TemporaryFileName, \ 101*da0073e9SAndroid Build Coastguard Worker enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \ 102*da0073e9SAndroid Build Coastguard Worker skipIfCrossRef, skipIfTorchDynamo 103*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \ 104*da0073e9SAndroid Build Coastguard Worker _trace, do_input_map, get_execution_plan, make_global, \ 105*da0073e9SAndroid Build Coastguard Worker execWrapper, _inline_everything, _tmp_donotuse_dont_inline_everything, \ 106*da0073e9SAndroid Build Coastguard Worker RUN_CUDA 107*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.jit_metaprogramming_utils import ( 108*da0073e9SAndroid Build Coastguard Worker get_script_args, 109*da0073e9SAndroid Build Coastguard Worker create_input, unpack_variables, 110*da0073e9SAndroid Build Coastguard Worker additional_module_tests, EXCLUDE_SCRIPT_MODULES, 111*da0073e9SAndroid Build Coastguard Worker get_nn_module_name_from_kwargs, get_nn_mod_test_name, script_method_template) 112*da0073e9SAndroid Build Coastguard Worker 113*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_nn import module_tests, new_module_tests, criterion_tests 114*da0073e9SAndroid Build Coastguard Worker 115*da0073e9SAndroid Build Coastguard Worker# For testing truediv in python 2 116*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.test_module.future_div import div_int_future, div_float_future 117*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.test_module.no_future_div import div_int_nofuture, div_float_nofuture 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker# Standard library 120*da0073e9SAndroid 
from collections import defaultdict, namedtuple, OrderedDict
from copy import deepcopy
from itertools import product
from textwrap import dedent
from typing import List, Dict, NamedTuple, Optional, Tuple, Union
import copy
import functools
import inspect
import io
import itertools
import math
import numpy as np
import os
import pickle
import pickletools
import random
import re
import shutil
import string
import sys
import tempfile
import types
import typing
import unittest
import warnings
import zipfile
import tracemalloc


def canonical(graph):
    """Return the printed form of ``graph`` after JIT canonicalization.

    The graph is passed through ``torch._C._jit_pass_canonicalize`` and the
    graph returned by that pass is rendered with ``.str(False)``; the
    ``False`` argument turns off printing of source locations, so the text is
    suitable for comparing graph structure in tests.

    Args:
        graph: a TorchScript IR graph (``torch._C.Graph``).

    Returns:
        str: the textual IR of the canonicalized graph.
    """
    canonicalized = torch._C._jit_pass_canonicalize(graph)
    return canonicalized.str(False)

152*da0073e9SAndroid Build Coastguard Workerdef LSTMCellF(input, hx, cx, *params): 153*da0073e9SAndroid Build Coastguard Worker return LSTMCell(input, (hx, cx), *params) 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Workerdef doAutodiffCheck(testname): 156*da0073e9SAndroid Build Coastguard Worker # TODO: setting false on test itself is not working 157*da0073e9SAndroid Build Coastguard Worker if "test_t_" in testname or testname == "test_t": 158*da0073e9SAndroid Build Coastguard Worker return False 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.SIMPLE: 161*da0073e9SAndroid Build Coastguard Worker return False 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 164*da0073e9SAndroid Build Coastguard Worker return True 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker 167*da0073e9SAndroid Build Coastguard Worker # these tests are disabled because BailOut nodes 168*da0073e9SAndroid Build Coastguard Worker # inserted by ProfilingExecutor interfere with 169*da0073e9SAndroid Build Coastguard Worker # subgraph slicing of Differentiable Graphs 170*da0073e9SAndroid Build Coastguard Worker test_exceptions = ( 171*da0073e9SAndroid Build Coastguard Worker # functional 172*da0073e9SAndroid Build Coastguard Worker 'test_nn_dropout', 173*da0073e9SAndroid Build Coastguard Worker 'test_nn_log_softmax', 174*da0073e9SAndroid Build Coastguard Worker 'test_nn_relu', 175*da0073e9SAndroid Build Coastguard Worker 'test_nn_softmax', 176*da0073e9SAndroid Build Coastguard Worker 'test_nn_threshold', 177*da0073e9SAndroid Build Coastguard Worker 'test_nn_lp_pool2d', 178*da0073e9SAndroid Build Coastguard Worker 'test_nn_lp_pool1d', 179*da0073e9SAndroid Build Coastguard Worker 'test_nn_gumbel_softmax_hard', 180*da0073e9SAndroid Build Coastguard Worker 
'test_nn_gumbel_softmax', 181*da0073e9SAndroid Build Coastguard Worker 'test_nn_multilabel_soft_margin_loss', 182*da0073e9SAndroid Build Coastguard Worker 'test_nn_batch_norm', 183*da0073e9SAndroid Build Coastguard Worker 'test_nn_max_pool2d_with_indices', 184*da0073e9SAndroid Build Coastguard Worker # AutogradJitGenerated 185*da0073e9SAndroid Build Coastguard Worker 'test___rdiv___constant', 186*da0073e9SAndroid Build Coastguard Worker 'test___rdiv___scalar_constant', 187*da0073e9SAndroid Build Coastguard Worker 'test_split', 188*da0073e9SAndroid Build Coastguard Worker 'test_split_dim', 189*da0073e9SAndroid Build Coastguard Worker 'test_split_dim_neg0', 190*da0073e9SAndroid Build Coastguard Worker 'test_split_size_list', 191*da0073e9SAndroid Build Coastguard Worker 'test_split_size_list_dim', 192*da0073e9SAndroid Build Coastguard Worker 'test_split_size_list_dim_neg0', 193*da0073e9SAndroid Build Coastguard Worker 'test_split_with_sizes', 194*da0073e9SAndroid Build Coastguard Worker 'test_split_with_sizes_dim', 195*da0073e9SAndroid Build Coastguard Worker 'test_split_with_sizes_dim_neg0', 196*da0073e9SAndroid Build Coastguard Worker 'test_split_with_sizes_size_0', 197*da0073e9SAndroid Build Coastguard Worker 'test_nn_max_pool2d_with_indices', 198*da0073e9SAndroid Build Coastguard Worker ) 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker return testname not in test_exceptions 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker# TODO: enable TE in PE when all tests are fixed 204*da0073e9SAndroid Build Coastguard Workertorch._C._jit_set_texpr_fuser_enabled(GRAPH_EXECUTOR == ProfilingMode.PROFILING) 205*da0073e9SAndroid Build Coastguard Workertorch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY) 206*da0073e9SAndroid Build Coastguard Worker 207*da0073e9SAndroid Build Coastguard Workerdef LSTMCell(input, hidden, w_ih, w_hh, 
b_ih=None, b_hh=None): 208*da0073e9SAndroid Build Coastguard Worker hx, cx = hidden 209*da0073e9SAndroid Build Coastguard Worker gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh) 210*da0073e9SAndroid Build Coastguard Worker 211*da0073e9SAndroid Build Coastguard Worker ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 212*da0073e9SAndroid Build Coastguard Worker ingate = torch.sigmoid(ingate) 213*da0073e9SAndroid Build Coastguard Worker forgetgate = torch.sigmoid(forgetgate) 214*da0073e9SAndroid Build Coastguard Worker cellgate = torch.tanh(cellgate) 215*da0073e9SAndroid Build Coastguard Worker outgate = torch.sigmoid(outgate) 216*da0073e9SAndroid Build Coastguard Worker 217*da0073e9SAndroid Build Coastguard Worker cy = (forgetgate * cx) + (ingate * cellgate) 218*da0073e9SAndroid Build Coastguard Worker hy = outgate * torch.tanh(cy) 219*da0073e9SAndroid Build Coastguard Worker return hy, cy 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Workerdef LSTMCellC(*args, **kwargs): 223*da0073e9SAndroid Build Coastguard Worker hy, cy = LSTMCellF(*args, **kwargs) 224*da0073e9SAndroid Build Coastguard Worker return torch.cat((hy, cy)) 225*da0073e9SAndroid Build Coastguard Worker 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Workerdef LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh): 228*da0073e9SAndroid Build Coastguard Worker gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh 229*da0073e9SAndroid Build Coastguard Worker ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 230*da0073e9SAndroid Build Coastguard Worker ingate = torch.sigmoid(ingate) 231*da0073e9SAndroid Build Coastguard Worker forgetgate = torch.sigmoid(forgetgate) 232*da0073e9SAndroid Build Coastguard Worker cellgate = torch.tanh(cellgate) 233*da0073e9SAndroid Build Coastguard Worker outgate = torch.sigmoid(outgate) 234*da0073e9SAndroid Build Coastguard Worker cy = 
(forgetgate * cx) + (ingate * cellgate) 235*da0073e9SAndroid Build Coastguard Worker hy = outgate * torch.tanh(cy) 236*da0073e9SAndroid Build Coastguard Worker return hy, cy 237*da0073e9SAndroid Build Coastguard Worker 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker# Code reference: https://github.com/pytorch/translate/blob/master/pytorch_translate/rnn_cell.py#L27:44 240*da0073e9SAndroid Build Coastguard Workerdef MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias): 241*da0073e9SAndroid Build Coastguard Worker Wx = x.mm(w_ih.t()) 242*da0073e9SAndroid Build Coastguard Worker Uz = hx.mm(w_hh.t()) 243*da0073e9SAndroid Build Coastguard Worker # Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf 244*da0073e9SAndroid Build Coastguard Worker gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias 245*da0073e9SAndroid Build Coastguard Worker # Same as LSTMCell after this point 246*da0073e9SAndroid Build Coastguard Worker ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 247*da0073e9SAndroid Build Coastguard Worker ingate = ingate.sigmoid() 248*da0073e9SAndroid Build Coastguard Worker forgetgate = forgetgate.sigmoid() 249*da0073e9SAndroid Build Coastguard Worker cellgate = cellgate.tanh() 250*da0073e9SAndroid Build Coastguard Worker outgate = outgate.sigmoid() 251*da0073e9SAndroid Build Coastguard Worker cy = (forgetgate * cx) + (ingate * cellgate) 252*da0073e9SAndroid Build Coastguard Worker hy = outgate * cy.tanh() 253*da0073e9SAndroid Build Coastguard Worker return hy, cy 254*da0073e9SAndroid Build Coastguard Worker 255*da0073e9SAndroid Build Coastguard Worker 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Workerdef get_lstm_inputs(device, training=False, seq_length=None): 258*da0073e9SAndroid Build Coastguard Worker input_shape = (3, 10) if seq_length is None else (seq_length, 3, 10) 259*da0073e9SAndroid Build Coastguard Worker input = torch.randn(*input_shape, 
dtype=torch.float, device=device, requires_grad=training) 260*da0073e9SAndroid Build Coastguard Worker hx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training) 261*da0073e9SAndroid Build Coastguard Worker cx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training) 262*da0073e9SAndroid Build Coastguard Worker module = nn.LSTMCell(10, 20).to(device, torch.float) # Just to allocate weights with correct sizes 263*da0073e9SAndroid Build Coastguard Worker if training: 264*da0073e9SAndroid Build Coastguard Worker params = tuple(module.parameters()) 265*da0073e9SAndroid Build Coastguard Worker else: 266*da0073e9SAndroid Build Coastguard Worker params = tuple(p.requires_grad_(False) for p in module.parameters()) 267*da0073e9SAndroid Build Coastguard Worker return (input, hx, cx) + params 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker 270*da0073e9SAndroid Build Coastguard Workerdef get_milstm_inputs(device, training=False): 271*da0073e9SAndroid Build Coastguard Worker minibatch = 3 272*da0073e9SAndroid Build Coastguard Worker input_size = 10 273*da0073e9SAndroid Build Coastguard Worker hidden_size = 20 274*da0073e9SAndroid Build Coastguard Worker x = torch.randn(minibatch, input_size, device=device, dtype=torch.float) 275*da0073e9SAndroid Build Coastguard Worker hx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float) 276*da0073e9SAndroid Build Coastguard Worker cx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float) 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker ih = torch.randn(4 * hidden_size, input_size, device=device, dtype=torch.float, requires_grad=training) 279*da0073e9SAndroid Build Coastguard Worker hh = torch.randn(4 * hidden_size, hidden_size, device=device, dtype=torch.float, requires_grad=training) 280*da0073e9SAndroid Build Coastguard Worker alpha = torch.randn(4 * hidden_size, 
dtype=torch.float, device=device, requires_grad=training) 281*da0073e9SAndroid Build Coastguard Worker ibeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) 282*da0073e9SAndroid Build Coastguard Worker hbeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) 283*da0073e9SAndroid Build Coastguard Worker bias = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training) 284*da0073e9SAndroid Build Coastguard Worker return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias 285*da0073e9SAndroid Build Coastguard Worker 286*da0073e9SAndroid Build Coastguard Worker 287*da0073e9SAndroid Build Coastguard Workerdef get_fn(file_name, script_path): 288*da0073e9SAndroid Build Coastguard Worker import importlib.util 289*da0073e9SAndroid Build Coastguard Worker spec = importlib.util.spec_from_file_location(file_name, script_path) 290*da0073e9SAndroid Build Coastguard Worker module = importlib.util.module_from_spec(spec) 291*da0073e9SAndroid Build Coastguard Worker spec.loader.exec_module(module) 292*da0073e9SAndroid Build Coastguard Worker fn = module.fn 293*da0073e9SAndroid Build Coastguard Worker return fn 294*da0073e9SAndroid Build Coastguard Worker 295*da0073e9SAndroid Build Coastguard Workerdef get_grad_executor(plan_state, diff_graph_idx=None, skip_check=False): 296*da0073e9SAndroid Build Coastguard Worker if diff_graph_idx is None: 297*da0073e9SAndroid Build Coastguard Worker nodes = list(plan_state.graph.nodes()) 298*da0073e9SAndroid Build Coastguard Worker 299*da0073e9SAndroid Build Coastguard Worker if not skip_check: 300*da0073e9SAndroid Build Coastguard Worker nodes = list(filter(lambda n : n.kind() != "prim::BailOut" and n.kind() != "prim::BailoutTemplate", nodes)) 301*da0073e9SAndroid Build Coastguard Worker if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"): 302*da0073e9SAndroid Build Coastguard Worker pass 303*da0073e9SAndroid Build 
Coastguard Worker elif len(nodes) == 2 and nodes[0].kind() == "prim::RequiresGradCheck" and nodes[1].kind() == "prim::If": 304*da0073e9SAndroid Build Coastguard Worker pass 305*da0073e9SAndroid Build Coastguard Worker else: 306*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Can't get a grad_executor for a non-differentiable graph") 307*da0073e9SAndroid Build Coastguard Worker grad_executors = list(plan_state.code.grad_executor_states()) 308*da0073e9SAndroid Build Coastguard Worker return grad_executors[diff_graph_idx or 0] 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker 311*da0073e9SAndroid Build Coastguard Workerdef all_backward_graphs(script_module, diff_graph_idx=None): 312*da0073e9SAndroid Build Coastguard Worker # Note: for Python 2 the order seems to be unstable 313*da0073e9SAndroid Build Coastguard Worker ge_state = script_module.get_debug_state() 314*da0073e9SAndroid Build Coastguard Worker fwd_plan = get_execution_plan(ge_state) 315*da0073e9SAndroid Build Coastguard Worker grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx) 316*da0073e9SAndroid Build Coastguard Worker bwd_plans = list(grad_executor_state.execution_plans.values()) 317*da0073e9SAndroid Build Coastguard Worker return [p.graph.copy() for p in bwd_plans] 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker 320*da0073e9SAndroid Build Coastguard Workerdef backward_graph(script_module, diff_graph_idx=None, skip_check=False): 321*da0073e9SAndroid Build Coastguard Worker ge_state = script_module.get_debug_state() 322*da0073e9SAndroid Build Coastguard Worker fwd_plan = get_execution_plan(ge_state) 323*da0073e9SAndroid Build Coastguard Worker grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx, skip_check=skip_check) 324*da0073e9SAndroid Build Coastguard Worker bwd_plan = get_execution_plan(grad_executor_state) 325*da0073e9SAndroid Build Coastguard 
Worker # Running JIT passes requires that we own the graph (with a shared_ptr). 326*da0073e9SAndroid Build Coastguard Worker # The debug state struct does not own its graph so we make a copy of it. 327*da0073e9SAndroid Build Coastguard Worker return bwd_plan.graph.copy() 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Worker# helper function to get sum of List[Tensor] 331*da0073e9SAndroid Build Coastguard Workerdef _sum_of_list(tensorlist): 332*da0073e9SAndroid Build Coastguard Worker s = 0 333*da0073e9SAndroid Build Coastguard Worker for t in tensorlist: 334*da0073e9SAndroid Build Coastguard Worker s += t.sum() 335*da0073e9SAndroid Build Coastguard Worker return s 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker# has to be at top level or Pickle complains 339*da0073e9SAndroid Build Coastguard Workerclass FooToPickle(torch.nn.Module): 340*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 341*da0073e9SAndroid Build Coastguard Worker super().__init__() 342*da0073e9SAndroid Build Coastguard Worker self.bar = torch.jit.ScriptModule() 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker 345*da0073e9SAndroid Build Coastguard Workerclass TestJitProfiler(JitTestCase): 346*da0073e9SAndroid Build Coastguard Worker """ 347*da0073e9SAndroid Build Coastguard Worker This runs tests that requires setting some global states like torch._C._set_graph_executor_optimize 348*da0073e9SAndroid Build Coastguard Worker and restore the values afterward, i.e. test_profiler. 
This is to address the flaky issue in 349*da0073e9SAndroid Build Coastguard Worker https://github.com/pytorch/pytorch/issues/91483 in which test_profiler was flaky and failed in the 350*da0073e9SAndroid Build Coastguard Worker middle without the chance to restore torch._C._set_graph_executor_optimize to its original value. 351*da0073e9SAndroid Build Coastguard Worker This causes issues for all future tests running after. 352*da0073e9SAndroid Build Coastguard Worker 353*da0073e9SAndroid Build Coastguard Worker Using a separate test class here, so that there is no need to run setup and teardown for all tests 354*da0073e9SAndroid Build Coastguard Worker in TestJit. 355*da0073e9SAndroid Build Coastguard Worker """ 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker def setUp(self): 358*da0073e9SAndroid Build Coastguard Worker super().setUp() 359*da0073e9SAndroid Build Coastguard Worker self.graph_executor_optimize_opt = torch._C._get_graph_executor_optimize() 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 362*da0073e9SAndroid Build Coastguard Worker super().tearDown() 363*da0073e9SAndroid Build Coastguard Worker # Resetting 364*da0073e9SAndroid Build Coastguard Worker torch._C._set_graph_executor_optimize( 365*da0073e9SAndroid Build Coastguard Worker self.graph_executor_optimize_opt 366*da0073e9SAndroid Build Coastguard Worker ) 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker def test_profiler(self): 369*da0073e9SAndroid Build Coastguard Worker torch._C._set_graph_executor_optimize(False) 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker def other_fn(x): 372*da0073e9SAndroid Build Coastguard Worker return x * 2 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 375*da0073e9SAndroid Build Coastguard Worker traced_other_fn = 
torch.jit.trace(other_fn, x) 376*da0073e9SAndroid Build Coastguard Worker 377*da0073e9SAndroid Build Coastguard Worker def fn(x): 378*da0073e9SAndroid Build Coastguard Worker y = traced_other_fn(x) 379*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(traced_other_fn, x) 380*da0073e9SAndroid Build Coastguard Worker y = torch.jit._wait(fut) 381*da0073e9SAndroid Build Coastguard Worker return y 382*da0073e9SAndroid Build Coastguard Worker 383*da0073e9SAndroid Build Coastguard Worker traced_fn = torch.jit.trace(fn, x) 384*da0073e9SAndroid Build Coastguard Worker with torch.autograd.profiler.profile() as prof: 385*da0073e9SAndroid Build Coastguard Worker traced_fn(x) 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker # expecting to see other_fn TS function call 388*da0073e9SAndroid Build Coastguard Worker # with cpu time >= mul cpu time and 389*da0073e9SAndroid Build Coastguard Worker # a forked other_fn 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker mul_events = defaultdict(int) 392*da0073e9SAndroid Build Coastguard Worker other_fn_events = defaultdict(int) 393*da0073e9SAndroid Build Coastguard Worker for e in prof.function_events: 394*da0073e9SAndroid Build Coastguard Worker if e.name == "aten::mul": 395*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e.thread not in mul_events) 396*da0073e9SAndroid Build Coastguard Worker mul_events[e.thread] = e.time_range.elapsed_us() 397*da0073e9SAndroid Build Coastguard Worker elif e.name == "other_fn": 398*da0073e9SAndroid Build Coastguard Worker self.assertTrue(e.thread not in other_fn_events) 399*da0073e9SAndroid Build Coastguard Worker other_fn_events[e.thread] = e.time_range.elapsed_us() 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(mul_events) == 2) 402*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(other_fn_events) == 2) 403*da0073e9SAndroid 
Build Coastguard Worker 404*da0073e9SAndroid Build Coastguard Worker for thread, mul_time in mul_events.items(): 405*da0073e9SAndroid Build Coastguard Worker self.assertTrue(thread in other_fn_events) 406*da0073e9SAndroid Build Coastguard Worker self.assertTrue(other_fn_events[thread] >= mul_time) 407*da0073e9SAndroid Build Coastguard Worker 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Workerclass TestJit(JitTestCase): 410*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Requires a lot of RAM") 411*da0073e9SAndroid Build Coastguard Worker def test_big(self): 412*da0073e9SAndroid Build Coastguard Worker m = torch.jit.ScriptModule() 413*da0073e9SAndroid Build Coastguard Worker gig = int(1024 * 1024 * 1024 / 4) 414*da0073e9SAndroid Build Coastguard Worker # a small tensor in the first 4GB 415*da0073e9SAndroid Build Coastguard Worker m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float)) 416*da0073e9SAndroid Build Coastguard Worker # a large tensor in the first 4GB that ends outside of it 417*da0073e9SAndroid Build Coastguard Worker m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float)) 418*da0073e9SAndroid Build Coastguard Worker # a small tensor in >4GB space 419*da0073e9SAndroid Build Coastguard Worker m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float)) 420*da0073e9SAndroid Build Coastguard Worker # s large tensor in the > 4GB space 421*da0073e9SAndroid Build Coastguard Worker m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float)) 422*da0073e9SAndroid Build Coastguard Worker 423*da0073e9SAndroid Build Coastguard Worker m2 = self.getExportImportCopy(m) 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) 426*da0073e9SAndroid Build Coastguard Worker 427*da0073e9SAndroid Build Coastguard Worker def test_inferred_as_tensor(self): 428*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, "Inferred the value for argument 'dim' to be of type 'Tensor' " 429*da0073e9SAndroid Build Coastguard Worker "because it was not annotated with an explicit type"): 430*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 431*da0073e9SAndroid Build Coastguard Worker def dot(points, query, dim): 432*da0073e9SAndroid Build Coastguard Worker return (points * query).sum(dim) 433*da0073e9SAndroid Build Coastguard Worker 434*da0073e9SAndroid Build Coastguard Worker def test_constants_pkl(self): 435*da0073e9SAndroid Build Coastguard Worker # This test asserts that the serialization archive includes a `constants.pkl` 436*da0073e9SAndroid Build Coastguard Worker # file. This file is used by `torch.load` to determine whether a zip file 437*da0073e9SAndroid Build Coastguard Worker # is a normal eager-mode serialization zip or a jit serialization zip. If 438*da0073e9SAndroid Build Coastguard Worker # you are deleting `constants.pkl`, make sure to update `torch.serialization.load` 439*da0073e9SAndroid Build Coastguard Worker # so it is still able to figure out which is which. 
440*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 441*da0073e9SAndroid Build Coastguard Worker def fn(x): 442*da0073e9SAndroid Build Coastguard Worker return x 443*da0073e9SAndroid Build Coastguard Worker 444*da0073e9SAndroid Build Coastguard Worker buf = io.BytesIO() 445*da0073e9SAndroid Build Coastguard Worker torch.jit.save(fn, buf) 446*da0073e9SAndroid Build Coastguard Worker buf.seek(0) 447*da0073e9SAndroid Build Coastguard Worker 448*da0073e9SAndroid Build Coastguard Worker files = zipfile.ZipFile(buf).filelist 449*da0073e9SAndroid Build Coastguard Worker self.assertTrue(any('archive/constants.pkl' == f.filename for f in files)) 450*da0073e9SAndroid Build Coastguard Worker 451*da0073e9SAndroid Build Coastguard Worker def test_script_fn_pkl(self): 452*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(pickle.PickleError, "ScriptFunction cannot be pickled"): 453*da0073e9SAndroid Build Coastguard Worker 454*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 455*da0073e9SAndroid Build Coastguard Worker def fn(x: torch.Tensor) -> torch.Tensor: 456*da0073e9SAndroid Build Coastguard Worker return x 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker pkl_fn = pickle.dumps(fn, protocol=0) 459*da0073e9SAndroid Build Coastguard Worker 460*da0073e9SAndroid Build Coastguard Worker def test_restore_device(self): 461*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 462*da0073e9SAndroid Build Coastguard Worker def __init__(self, cpu_device_str): 463*da0073e9SAndroid Build Coastguard Worker super().__init__() 464*da0073e9SAndroid Build Coastguard Worker self.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float, 465*da0073e9SAndroid Build Coastguard Worker device=cpu_device_str)) 466*da0073e9SAndroid Build Coastguard Worker self.b0 = torch.tensor([0.9], dtype=torch.float, 467*da0073e9SAndroid Build Coastguard Worker device=cpu_device_str) 468*da0073e9SAndroid Build 
Coastguard Worker 469*da0073e9SAndroid Build Coastguard Worker # main purpose is checking map_location works 470*da0073e9SAndroid Build Coastguard Worker m = M("cpu") 471*da0073e9SAndroid Build Coastguard Worker m2 = self.getExportImportCopy(m) 472*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) 473*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(m.buffers()), tuple(m2.buffers())) 474*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2.p0.is_cuda) 475*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m2.b0.is_cuda) 476*da0073e9SAndroid Build Coastguard Worker 477*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") 478*da0073e9SAndroid Build Coastguard Worker def test_restore_device_cuda(self): 479*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.jit.ScriptModule): 480*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 481*da0073e9SAndroid Build Coastguard Worker super().__init__() 482*da0073e9SAndroid Build Coastguard Worker self.b0 = nn.Buffer(torch.randn(1, 3)) 483*da0073e9SAndroid Build Coastguard Worker self.p0 = nn.Parameter(torch.randn(2, 3)) 484*da0073e9SAndroid Build Coastguard Worker 485*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 486*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 487*da0073e9SAndroid Build Coastguard Worker return x + self.b0 + self.p0 488*da0073e9SAndroid Build Coastguard Worker 489*da0073e9SAndroid Build Coastguard Worker m = MyModule() 490*da0073e9SAndroid Build Coastguard Worker m.cuda(torch.cuda.device_count() - 1) 491*da0073e9SAndroid Build Coastguard Worker cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1) 492*da0073e9SAndroid Build Coastguard Worker 493*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m.p0.is_cuda) 494*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m.b0.is_cuda) 
495*da0073e9SAndroid Build Coastguard Worker 496*da0073e9SAndroid Build Coastguard Worker # restore to the saved devices 497*da0073e9SAndroid Build Coastguard Worker m2 = self.getExportImportCopy(m) 498*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) 499*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(m.buffers()), tuple(m2.buffers())) 500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(m2.p0.device), cuda_device_str) 501*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(m2.b0.device), cuda_device_str) 502*da0073e9SAndroid Build Coastguard Worker 503*da0073e9SAndroid Build Coastguard Worker # restore all to cpu using string 504*da0073e9SAndroid Build Coastguard Worker cpu_device_str = 'cpu' 505*da0073e9SAndroid Build Coastguard Worker m3 = self.getExportImportCopy(m, map_location=cpu_device_str) 506*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(m3.p0.device), cpu_device_str) 507*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(m3.b0.device), cpu_device_str) 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker # restore all to first gpu using device 510*da0073e9SAndroid Build Coastguard Worker m4 = self.getExportImportCopy( 511*da0073e9SAndroid Build Coastguard Worker m3, map_location=torch.device('cuda:0')) 512*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(m4.p0.device), 'cuda:0') 513*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(m4.b0.device), 'cuda:0') 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker # compute and compare the results 516*da0073e9SAndroid Build Coastguard Worker input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1) 517*da0073e9SAndroid Build Coastguard Worker origin_result = m(input) 518*da0073e9SAndroid Build Coastguard Worker self.assertEqual(origin_result, m2(input)) 519*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(origin_result, m3(input.cpu())) 520*da0073e9SAndroid Build Coastguard Worker self.assertEqual(origin_result, m4(input.cuda(0))) 521*da0073e9SAndroid Build Coastguard Worker 522*da0073e9SAndroid Build Coastguard Worker def test_trace_retains_train(self): 523*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 524*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 525*da0073e9SAndroid Build Coastguard Worker return x 526*da0073e9SAndroid Build Coastguard Worker m = M() 527*da0073e9SAndroid Build Coastguard Worker m.eval() 528*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(m, (torch.rand(3))) 529*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tm.training, m.training) 530*da0073e9SAndroid Build Coastguard Worker 531*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA") 532*da0073e9SAndroid Build Coastguard Worker def test_restore_shared_storage_on_cuda(self): 533*da0073e9SAndroid Build Coastguard Worker class Foo(torch.jit.ScriptModule): 534*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 535*da0073e9SAndroid Build Coastguard Worker super().__init__() 536*da0073e9SAndroid Build Coastguard Worker whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu') 537*da0073e9SAndroid Build Coastguard Worker self.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1)) 538*da0073e9SAndroid Build Coastguard Worker self.b0 = nn.Buffer(whole_tensor.narrow(0, 3, 1)) 539*da0073e9SAndroid Build Coastguard Worker 540*da0073e9SAndroid Build Coastguard Worker m = Foo() 541*da0073e9SAndroid Build Coastguard Worker m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0')) 542*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(m.parameters()), tuple(m2.parameters())) 543*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple(m.buffers()), tuple(m2.buffers())) 544*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(m2.p0.is_cuda) 545*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2.b0.is_cuda) 546*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2.p0.is_shared()) 547*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m2.b0.is_shared()) 548*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr()) 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker def test_add_relu_fusion(self): 551*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 552*da0073e9SAndroid Build Coastguard Worker def __init__(self, relu_op): 553*da0073e9SAndroid Build Coastguard Worker super().__init__() 554*da0073e9SAndroid Build Coastguard Worker self.relu_op = relu_op 555*da0073e9SAndroid Build Coastguard Worker 556*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b, c): 557*da0073e9SAndroid Build Coastguard Worker tmp = torch.add(a, b) 558*da0073e9SAndroid Build Coastguard Worker x = self.relu_op(tmp) 559*da0073e9SAndroid Build Coastguard Worker d = torch.add(a, c) 560*da0073e9SAndroid Build Coastguard Worker return x + d 561*da0073e9SAndroid Build Coastguard Worker a = torch.rand((7, 11)) 562*da0073e9SAndroid Build Coastguard Worker a = a * -10 563*da0073e9SAndroid Build Coastguard Worker a = a + 5 564*da0073e9SAndroid Build Coastguard Worker b = torch.rand((7, 11)) 565*da0073e9SAndroid Build Coastguard Worker c = torch.rand((7, 11)) 566*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(M(torch.relu)) 567*da0073e9SAndroid Build Coastguard Worker orig_res = m(a, b, c) 568*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_fuse_add_relu(m.graph) 569*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 570*da0073e9SAndroid Build Coastguard Worker torch.jit.save(m, buffer) 571*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 572*da0073e9SAndroid Build Coastguard Worker m = torch.jit.load(buffer) 
573*da0073e9SAndroid Build Coastguard Worker new_res = m(a, b, c) 574*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::relu(") \ 575*da0073e9SAndroid Build Coastguard Worker .check("aten::_add_relu(") \ 576*da0073e9SAndroid Build Coastguard Worker .run(m.graph) 577*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(orig_res, new_res) 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker # add, relu_ 580*da0073e9SAndroid Build Coastguard Worker a = torch.rand((7, 11)) 581*da0073e9SAndroid Build Coastguard Worker a = a * -10 582*da0073e9SAndroid Build Coastguard Worker a = a + 5 583*da0073e9SAndroid Build Coastguard Worker b = torch.rand((7, 11)) 584*da0073e9SAndroid Build Coastguard Worker c = torch.rand((7, 11)) 585*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(M(torch.relu_)) 586*da0073e9SAndroid Build Coastguard Worker orig_res = m(a, b, c) 587*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_fuse_add_relu(m.graph) 588*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 589*da0073e9SAndroid Build Coastguard Worker torch.jit.save(m, buffer) 590*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 591*da0073e9SAndroid Build Coastguard Worker m = torch.jit.load(buffer) 592*da0073e9SAndroid Build Coastguard Worker new_res = m(a, b, c) 593*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::relu_(") \ 594*da0073e9SAndroid Build Coastguard Worker .check("aten::_add_relu(") \ 595*da0073e9SAndroid Build Coastguard Worker .run(m.graph) 596*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(orig_res, new_res) 597*da0073e9SAndroid Build Coastguard Worker 598*da0073e9SAndroid Build Coastguard Worker class Madd_(torch.nn.Module): 599*da0073e9SAndroid Build Coastguard Worker def __init__(self, relu_op): 600*da0073e9SAndroid Build Coastguard Worker super().__init__() 601*da0073e9SAndroid Build Coastguard Worker self.relu_op = 
relu_op 602*da0073e9SAndroid Build Coastguard Worker 603*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 604*da0073e9SAndroid Build Coastguard Worker x = a.add_(b) 605*da0073e9SAndroid Build Coastguard Worker x = self.relu_op(x) 606*da0073e9SAndroid Build Coastguard Worker return x 607*da0073e9SAndroid Build Coastguard Worker 608*da0073e9SAndroid Build Coastguard Worker # add_, relu_ 609*da0073e9SAndroid Build Coastguard Worker a = torch.rand((7, 11)) 610*da0073e9SAndroid Build Coastguard Worker a = a * -10 611*da0073e9SAndroid Build Coastguard Worker a = a + 5 612*da0073e9SAndroid Build Coastguard Worker b = torch.rand((7, 11)) 613*da0073e9SAndroid Build Coastguard Worker # Because in place add_ will overwrite a 614*da0073e9SAndroid Build Coastguard Worker a_copy = a.clone() 615*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Madd_(torch.relu_)) 616*da0073e9SAndroid Build Coastguard Worker orig_res = m(a, b) 617*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_fuse_add_relu(m.graph) 618*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 619*da0073e9SAndroid Build Coastguard Worker torch.jit.save(m, buffer) 620*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 621*da0073e9SAndroid Build Coastguard Worker m = torch.jit.load(buffer) 622*da0073e9SAndroid Build Coastguard Worker new_res = m(a_copy, b) 623*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::add_(") \ 624*da0073e9SAndroid Build Coastguard Worker .check_not("aten::relu_(") \ 625*da0073e9SAndroid Build Coastguard Worker .check("aten::_add_relu_(") \ 626*da0073e9SAndroid Build Coastguard Worker .run(m.graph) 627*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(orig_res, new_res) 628*da0073e9SAndroid Build Coastguard Worker # Since _add_relu_ does inplace mutation ensure 629*da0073e9SAndroid Build Coastguard Worker # a_copy is modified 630*da0073e9SAndroid Build Coastguard Worker 
torch.testing.assert_close(orig_res, a_copy) 631*da0073e9SAndroid Build Coastguard Worker 632*da0073e9SAndroid Build Coastguard Worker class Madd_out(torch.nn.Module): 633*da0073e9SAndroid Build Coastguard Worker def __init__(self, relu_op): 634*da0073e9SAndroid Build Coastguard Worker super().__init__() 635*da0073e9SAndroid Build Coastguard Worker self.relu_op = relu_op 636*da0073e9SAndroid Build Coastguard Worker 637*da0073e9SAndroid Build Coastguard Worker def forward(self, a, b): 638*da0073e9SAndroid Build Coastguard Worker x = torch.add(a, b, out=a) 639*da0073e9SAndroid Build Coastguard Worker x = self.relu_op(x) 640*da0073e9SAndroid Build Coastguard Worker return x 641*da0073e9SAndroid Build Coastguard Worker a = torch.rand((7, 11)) 642*da0073e9SAndroid Build Coastguard Worker a = a * -10 643*da0073e9SAndroid Build Coastguard Worker a = a + 5 644*da0073e9SAndroid Build Coastguard Worker b = torch.rand((7, 11)) 645*da0073e9SAndroid Build Coastguard Worker 646*da0073e9SAndroid Build Coastguard Worker # add_out, relu_ 647*da0073e9SAndroid Build Coastguard Worker a = torch.rand((7, 11)) 648*da0073e9SAndroid Build Coastguard Worker a = a * -10 649*da0073e9SAndroid Build Coastguard Worker a = a + 5 650*da0073e9SAndroid Build Coastguard Worker b = torch.rand((7, 11)) 651*da0073e9SAndroid Build Coastguard Worker # Because in place add_ will overwrite a 652*da0073e9SAndroid Build Coastguard Worker a_copy = a.clone() 653*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Madd_out(torch.relu_)) 654*da0073e9SAndroid Build Coastguard Worker orig_res = m(a, b) 655*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_fuse_add_relu(m.graph) 656*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 657*da0073e9SAndroid Build Coastguard Worker torch.jit.save(m, buffer) 658*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 659*da0073e9SAndroid Build Coastguard Worker m = torch.jit.load(buffer) 660*da0073e9SAndroid Build Coastguard Worker new_res 
= m(a_copy, b) 661*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::add(") \ 662*da0073e9SAndroid Build Coastguard Worker .check_not("aten::relu_(") \ 663*da0073e9SAndroid Build Coastguard Worker .check("aten::_add_relu(") \ 664*da0073e9SAndroid Build Coastguard Worker .run(m.graph) 665*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(orig_res, new_res) 666*da0073e9SAndroid Build Coastguard Worker # Since _add_relu_ with out=a does inplace mutation ensure 667*da0073e9SAndroid Build Coastguard Worker # a_copy is modified 668*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(orig_res, a_copy) 669*da0073e9SAndroid Build Coastguard Worker 670*da0073e9SAndroid Build Coastguard Worker def test_repeat_interleave_script(self): 671*da0073e9SAndroid Build Coastguard Worker def fn(input: torch.Tensor, repeats: torch.Tensor) -> torch.Tensor: 672*da0073e9SAndroid Build Coastguard Worker output = input.repeat_interleave(repeats) 673*da0073e9SAndroid Build Coastguard Worker return output 674*da0073e9SAndroid Build Coastguard Worker fn_scripted = torch.jit.script(fn) 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([5, 7], dtype=torch.int64) 677*da0073e9SAndroid Build Coastguard Worker repeats = torch.tensor([3, 6], dtype=torch.int64) 678*da0073e9SAndroid Build Coastguard Worker 679*da0073e9SAndroid Build Coastguard Worker output = fn(input, repeats) 680*da0073e9SAndroid Build Coastguard Worker output_scripted = fn_scripted(input, repeats) 681*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output_scripted, output) 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple executor doesn't have shape information") 684*da0073e9SAndroid Build Coastguard Worker def test_peephole_optimize_shape_ops(self): 685*da0073e9SAndroid Build Coastguard Worker def 
test_input(func, input, result): 686*da0073e9SAndroid Build Coastguard Worker # if result == 2 we will trigger a bailout and 687*da0073e9SAndroid Build Coastguard Worker # the unprofiled graph should return the correct result 688*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(input, profile_and_replay=True), result) 689*da0073e9SAndroid Build Coastguard Worker gre = func.graph_for(input) 690*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::If").run(gre) 691*da0073e9SAndroid Build Coastguard Worker 692*da0073e9SAndroid Build Coastguard Worker def test_dim(): 693*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 694*da0073e9SAndroid Build Coastguard Worker def func(x): 695*da0073e9SAndroid Build Coastguard Worker if x.dim() == 1: 696*da0073e9SAndroid Build Coastguard Worker return 1 697*da0073e9SAndroid Build Coastguard Worker else: 698*da0073e9SAndroid Build Coastguard Worker return 2 699*da0073e9SAndroid Build Coastguard Worker 700*da0073e9SAndroid Build Coastguard Worker test_input(func, torch.tensor([0.5]), 1) 701*da0073e9SAndroid Build Coastguard Worker test_input(func, torch.tensor([[0.5]]), 2) 702*da0073e9SAndroid Build Coastguard Worker test_dim() 703*da0073e9SAndroid Build Coastguard Worker 704*da0073e9SAndroid Build Coastguard Worker def test_size_index(): 705*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 706*da0073e9SAndroid Build Coastguard Worker def func(x): 707*da0073e9SAndroid Build Coastguard Worker if x.size(0) == 1: 708*da0073e9SAndroid Build Coastguard Worker return 1 709*da0073e9SAndroid Build Coastguard Worker else: 710*da0073e9SAndroid Build Coastguard Worker return 2 711*da0073e9SAndroid Build Coastguard Worker 712*da0073e9SAndroid Build Coastguard Worker test_input(func, torch.rand([1, 2]), 1) 713*da0073e9SAndroid Build Coastguard Worker test_input(func, torch.rand([1, 3]), 1) 714*da0073e9SAndroid Build Coastguard Worker 715*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 
716*da0073e9SAndroid Build Coastguard Worker def neg_index(x): 717*da0073e9SAndroid Build Coastguard Worker if x.size(-2) == 1: 718*da0073e9SAndroid Build Coastguard Worker return 1 719*da0073e9SAndroid Build Coastguard Worker else: 720*da0073e9SAndroid Build Coastguard Worker return 2 721*da0073e9SAndroid Build Coastguard Worker 722*da0073e9SAndroid Build Coastguard Worker test_input(neg_index, torch.rand([1, 2]), 1) 723*da0073e9SAndroid Build Coastguard Worker test_input(neg_index, torch.rand([1, 3]), 1) 724*da0073e9SAndroid Build Coastguard Worker 725*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.PROFILING: 726*da0073e9SAndroid Build Coastguard Worker test_size_index() 727*da0073e9SAndroid Build Coastguard Worker 728*da0073e9SAndroid Build Coastguard Worker def test_dtype(): 729*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 730*da0073e9SAndroid Build Coastguard Worker def func(x): 731*da0073e9SAndroid Build Coastguard Worker if x.dtype == torch.float32: 732*da0073e9SAndroid Build Coastguard Worker return 1 733*da0073e9SAndroid Build Coastguard Worker else: 734*da0073e9SAndroid Build Coastguard Worker return 2 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard Worker test_input(func, torch.tensor(0.5, dtype=torch.float32), 1) 737*da0073e9SAndroid Build Coastguard Worker test_input(func, torch.tensor(0.5, dtype=torch.int64), 2) 738*da0073e9SAndroid Build Coastguard Worker test_dtype() 739*da0073e9SAndroid Build Coastguard Worker 740*da0073e9SAndroid Build Coastguard Worker def test_is_floating_poiint(): 741*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 742*da0073e9SAndroid Build Coastguard Worker def func(x): 743*da0073e9SAndroid Build Coastguard Worker if x.is_floating_point(): 744*da0073e9SAndroid Build Coastguard Worker return 1 745*da0073e9SAndroid Build Coastguard Worker else: 746*da0073e9SAndroid Build Coastguard Worker return 2 747*da0073e9SAndroid Build Coastguard 
Worker 748*da0073e9SAndroid Build Coastguard Worker test_input(func, torch.tensor(0.5, dtype=torch.float32), 1) 749*da0073e9SAndroid Build Coastguard Worker test_input(func, torch.tensor(0.5, dtype=torch.int64), 2) 750*da0073e9SAndroid Build Coastguard Worker test_is_floating_poiint() 751*da0073e9SAndroid Build Coastguard Worker 752*da0073e9SAndroid Build Coastguard Worker def test_device(): 753*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 754*da0073e9SAndroid Build Coastguard Worker def func_1(x): 755*da0073e9SAndroid Build Coastguard Worker if x.device == torch.device('cuda:0'): 756*da0073e9SAndroid Build Coastguard Worker a = 0 757*da0073e9SAndroid Build Coastguard Worker else: 758*da0073e9SAndroid Build Coastguard Worker a = 1 759*da0073e9SAndroid Build Coastguard Worker return a 760*da0073e9SAndroid Build Coastguard Worker 761*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 762*da0073e9SAndroid Build Coastguard Worker def func_2(x): 763*da0073e9SAndroid Build Coastguard Worker if x.is_cuda: 764*da0073e9SAndroid Build Coastguard Worker a = 0 765*da0073e9SAndroid Build Coastguard Worker else: 766*da0073e9SAndroid Build Coastguard Worker a = 1 767*da0073e9SAndroid Build Coastguard Worker return a 768*da0073e9SAndroid Build Coastguard Worker 769*da0073e9SAndroid Build Coastguard Worker test_input(func_1, torch.tensor(0.5), 1) 770*da0073e9SAndroid Build Coastguard Worker test_input(func_2, torch.tensor(0.5), 1) 771*da0073e9SAndroid Build Coastguard Worker 772*da0073e9SAndroid Build Coastguard Worker if RUN_CUDA: 773*da0073e9SAndroid Build Coastguard Worker test_input(func_1, torch.tensor(0.5, device="cuda:0"), 0) 774*da0073e9SAndroid Build Coastguard Worker test_input(func_2, torch.tensor(0.5, device="cuda:0"), 0) 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Worker test_device() 777*da0073e9SAndroid Build Coastguard Worker 778*da0073e9SAndroid Build Coastguard Worker def test_attrs(self): 
779*da0073e9SAndroid Build Coastguard Worker def foo(x): 780*da0073e9SAndroid Build Coastguard Worker return ( 781*da0073e9SAndroid Build Coastguard Worker # x.dtype, TODO: dtype long -> instance conversion 782*da0073e9SAndroid Build Coastguard Worker x.device, 783*da0073e9SAndroid Build Coastguard Worker x.shape, 784*da0073e9SAndroid Build Coastguard Worker x.is_cuda, 785*da0073e9SAndroid Build Coastguard Worker x.is_mkldnn, 786*da0073e9SAndroid Build Coastguard Worker x.is_quantized, 787*da0073e9SAndroid Build Coastguard Worker x.requires_grad, 788*da0073e9SAndroid Build Coastguard Worker x.T, 789*da0073e9SAndroid Build Coastguard Worker x.mT, 790*da0073e9SAndroid Build Coastguard Worker x.H, 791*da0073e9SAndroid Build Coastguard Worker x.mH 792*da0073e9SAndroid Build Coastguard Worker # x.layout TODO: layout long -> instance conversion 793*da0073e9SAndroid Build Coastguard Worker ) 794*da0073e9SAndroid Build Coastguard Worker 795*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(foo) 796*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 797*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted(x), foo(x)) 798*da0073e9SAndroid Build Coastguard Worker 799*da0073e9SAndroid Build Coastguard Worker def test_layout(self): 800*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 801*da0073e9SAndroid Build Coastguard Worker def check(x, y): 802*da0073e9SAndroid Build Coastguard Worker return x.layout == y.layout 803*da0073e9SAndroid Build Coastguard Worker 804*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 805*da0073e9SAndroid Build Coastguard Worker y = torch.rand(3, 4) 806*da0073e9SAndroid Build Coastguard Worker 807*da0073e9SAndroid Build Coastguard Worker self.assertTrue(check(x, y)) 808*da0073e9SAndroid Build Coastguard Worker 809*da0073e9SAndroid Build Coastguard Worker def test_matrix_transpose(self): 810*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 811*da0073e9SAndroid Build 
Coastguard Worker def check(x): 812*da0073e9SAndroid Build Coastguard Worker return torch.equal(x.mT, x.transpose(-2, -1)) 813*da0073e9SAndroid Build Coastguard Worker 814*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 815*da0073e9SAndroid Build Coastguard Worker self.assertTrue(check(x)) 816*da0073e9SAndroid Build Coastguard Worker 817*da0073e9SAndroid Build Coastguard Worker def test_transpose(self): 818*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 819*da0073e9SAndroid Build Coastguard Worker def check(x): 820*da0073e9SAndroid Build Coastguard Worker return torch.equal(x.T, x.t()) 821*da0073e9SAndroid Build Coastguard Worker 822*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 823*da0073e9SAndroid Build Coastguard Worker self.assertTrue(check(x)) 824*da0073e9SAndroid Build Coastguard Worker 825*da0073e9SAndroid Build Coastguard Worker def test_matrix_conj_transpose(self): 826*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 827*da0073e9SAndroid Build Coastguard Worker def check(x): 828*da0073e9SAndroid Build Coastguard Worker return torch.equal(x.mH, x.transpose(-2, -1).conj()) 829*da0073e9SAndroid Build Coastguard Worker 830*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 831*da0073e9SAndroid Build Coastguard Worker self.assertTrue(check(x)) 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker x = make_tensor((3, 4), device="cpu", dtype=torch.complex64) 834*da0073e9SAndroid Build Coastguard Worker self.assertTrue(check(x)) 835*da0073e9SAndroid Build Coastguard Worker 836*da0073e9SAndroid Build Coastguard Worker def test_conj_transpose(self): 837*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 838*da0073e9SAndroid Build Coastguard Worker def check(x): 839*da0073e9SAndroid Build Coastguard Worker return torch.equal(x.H, x.t().conj()) 840*da0073e9SAndroid Build Coastguard Worker 841*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 
842*da0073e9SAndroid Build Coastguard Worker self.assertTrue(check(x)) 843*da0073e9SAndroid Build Coastguard Worker 844*da0073e9SAndroid Build Coastguard Worker x = make_tensor((3, 4), device="cpu", dtype=torch.complex64) 845*da0073e9SAndroid Build Coastguard Worker self.assertTrue(check(x)) 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker def test_T_mT_H_mH(self): 848*da0073e9SAndroid Build Coastguard Worker def T(x): 849*da0073e9SAndroid Build Coastguard Worker return x.mT 850*da0073e9SAndroid Build Coastguard Worker 851*da0073e9SAndroid Build Coastguard Worker def mT(x): 852*da0073e9SAndroid Build Coastguard Worker return x.mT 853*da0073e9SAndroid Build Coastguard Worker 854*da0073e9SAndroid Build Coastguard Worker def H(x): 855*da0073e9SAndroid Build Coastguard Worker return x.H 856*da0073e9SAndroid Build Coastguard Worker 857*da0073e9SAndroid Build Coastguard Worker def mH(x): 858*da0073e9SAndroid Build Coastguard Worker return x.mH 859*da0073e9SAndroid Build Coastguard Worker 860*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 861*da0073e9SAndroid Build Coastguard Worker y = make_tensor((3, 4), device="cpu", dtype=torch.complex64) 862*da0073e9SAndroid Build Coastguard Worker 863*da0073e9SAndroid Build Coastguard Worker self.checkScript(T, (x, )) 864*da0073e9SAndroid Build Coastguard Worker self.checkScript(mT, (x, )) 865*da0073e9SAndroid Build Coastguard Worker self.checkScript(H, (x, )) 866*da0073e9SAndroid Build Coastguard Worker self.checkScript(mH, (x, )) 867*da0073e9SAndroid Build Coastguard Worker self.checkScript(T, (y, )) 868*da0073e9SAndroid Build Coastguard Worker self.checkScript(mT, (y, )) 869*da0073e9SAndroid Build Coastguard Worker self.checkScript(H, (y, )) 870*da0073e9SAndroid Build Coastguard Worker self.checkScript(mH, (y, )) 871*da0073e9SAndroid Build Coastguard Worker 872*da0073e9SAndroid Build Coastguard Worker def test_nn_conv(self): 873*da0073e9SAndroid Build Coastguard Worker class 
Mod(nn.Module): 874*da0073e9SAndroid Build Coastguard Worker def __init__(self, conv): 875*da0073e9SAndroid Build Coastguard Worker super().__init__() 876*da0073e9SAndroid Build Coastguard Worker self.conv = conv 877*da0073e9SAndroid Build Coastguard Worker 878*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 879*da0073e9SAndroid Build Coastguard Worker return self.conv(input) 880*da0073e9SAndroid Build Coastguard Worker 881*da0073e9SAndroid Build Coastguard Worker inputs = [ 882*da0073e9SAndroid Build Coastguard Worker # Conv 883*da0073e9SAndroid Build Coastguard Worker (Mod(nn.Conv1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)), 884*da0073e9SAndroid Build Coastguard Worker (Mod(nn.Conv2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)), 885*da0073e9SAndroid Build Coastguard Worker (Mod(nn.Conv3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)), 886*da0073e9SAndroid Build Coastguard Worker # ConvTransposed 887*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ConvTranspose1d(16, 33, 3, stride=2)), torch.randn(20, 16, 5)), 888*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ConvTranspose2d(16, 33, 3, stride=2)), torch.randn(20, 16, 5, 10)), 889*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ConvTranspose3d(16, 33, 3, stride=2)), torch.randn(20, 16, 3, 5, 4)), 890*da0073e9SAndroid Build Coastguard Worker ] 891*da0073e9SAndroid Build Coastguard Worker 892*da0073e9SAndroid Build Coastguard Worker for m, inp in inputs: 893*da0073e9SAndroid Build Coastguard Worker self.checkModule(m, (inp,)) 894*da0073e9SAndroid Build Coastguard Worker 895*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, 'Not implemented for Simple or Legacy') 896*da0073e9SAndroid Build Coastguard Worker def test_debug_flush_compilation_cache(self): 897*da0073e9SAndroid Build Coastguard Worker def foo(x): 898*da0073e9SAndroid Build Coastguard Worker return x + 2 899*da0073e9SAndroid Build Coastguard Worker 
900*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 901*da0073e9SAndroid Build Coastguard Worker def forward(self, t): 902*da0073e9SAndroid Build Coastguard Worker return t + 2 903*da0073e9SAndroid Build Coastguard Worker 904*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Mod()) 905*da0073e9SAndroid Build Coastguard Worker x = torch.rand(1, 10) 906*da0073e9SAndroid Build Coastguard Worker 907*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 908*da0073e9SAndroid Build Coastguard Worker jitted = self.checkScript(foo, (x,)) 909*da0073e9SAndroid Build Coastguard Worker # shouldn't throw 910*da0073e9SAndroid Build Coastguard Worker states = jitted.get_debug_state() 911*da0073e9SAndroid Build Coastguard Worker 912*da0073e9SAndroid Build Coastguard Worker # after flushing there shouldn't be 913*da0073e9SAndroid Build Coastguard Worker # no opt plan 914*da0073e9SAndroid Build Coastguard Worker jitted._debug_flush_compilation_cache() 915*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "INTERNAL ASSERT FAILED"): 916*da0073e9SAndroid Build Coastguard Worker states = jitted.get_debug_state() 917*da0073e9SAndroid Build Coastguard Worker 918*da0073e9SAndroid Build Coastguard Worker NUM_RUNS = 1 919*da0073e9SAndroid Build Coastguard Worker with num_profiled_runs(NUM_RUNS): 920*da0073e9SAndroid Build Coastguard Worker m(x) 921*da0073e9SAndroid Build Coastguard Worker m(x) 922*da0073e9SAndroid Build Coastguard Worker fwd = m._c._get_method("forward") 923*da0073e9SAndroid Build Coastguard Worker states = m.get_debug_state() 924*da0073e9SAndroid Build Coastguard Worker 925*da0073e9SAndroid Build Coastguard Worker # after flushing there shouldn't be 926*da0073e9SAndroid Build Coastguard Worker # no opt plan 927*da0073e9SAndroid Build Coastguard Worker fwd._debug_flush_compilation_cache() 928*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 
"INTERNAL ASSERT FAILED"): 929*da0073e9SAndroid Build Coastguard Worker states = m.get_debug_state() 930*da0073e9SAndroid Build Coastguard Worker 931*da0073e9SAndroid Build Coastguard Worker def test_numel(self): 932*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 933*da0073e9SAndroid Build Coastguard Worker def get_numel_script(x): 934*da0073e9SAndroid Build Coastguard Worker return x.numel() 935*da0073e9SAndroid Build Coastguard Worker 936*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 937*da0073e9SAndroid Build Coastguard Worker numel = get_numel_script(x) 938*da0073e9SAndroid Build Coastguard Worker self.assertEqual(numel, x.numel()) 939*da0073e9SAndroid Build Coastguard Worker 940*da0073e9SAndroid Build Coastguard Worker def test_element_size(self): 941*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 942*da0073e9SAndroid Build Coastguard Worker def get_element_size_script(x): 943*da0073e9SAndroid Build Coastguard Worker return x.element_size() 944*da0073e9SAndroid Build Coastguard Worker 945*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 946*da0073e9SAndroid Build Coastguard Worker element_size = get_element_size_script(x) 947*da0073e9SAndroid Build Coastguard Worker self.assertEqual(element_size, x.element_size()) 948*da0073e9SAndroid Build Coastguard Worker 949*da0073e9SAndroid Build Coastguard Worker def test_Sequential(self): 950*da0073e9SAndroid Build Coastguard Worker class Seq(nn.Module): 951*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 952*da0073e9SAndroid Build Coastguard Worker super().__init__() 953*da0073e9SAndroid Build Coastguard Worker self.seq = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 30)) 954*da0073e9SAndroid Build Coastguard Worker 955*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 956*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 957*da0073e9SAndroid Build Coastguard Worker for l in self.seq: 958*da0073e9SAndroid Build 
Coastguard Worker x = l(x) 959*da0073e9SAndroid Build Coastguard Worker return x 960*da0073e9SAndroid Build Coastguard Worker 961*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Seq()) 962*da0073e9SAndroid Build Coastguard Worker assert m.graph # ensure jit was able to compile 963*da0073e9SAndroid Build Coastguard Worker 964*da0073e9SAndroid Build Coastguard Worker def test_ModuleList(self): 965*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 966*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 967*da0073e9SAndroid Build Coastguard Worker super().__init__() 968*da0073e9SAndroid Build Coastguard Worker self.model = nn.ModuleList([nn.Linear(10, 10) for _ in range(10)]) 969*da0073e9SAndroid Build Coastguard Worker self.model += (nn.Linear(10, 20),) 970*da0073e9SAndroid Build Coastguard Worker self.model.append(nn.Linear(20, 30)) 971*da0073e9SAndroid Build Coastguard Worker self.model.extend([nn.Linear(30, 40), nn.Linear(40, 50)]) 972*da0073e9SAndroid Build Coastguard Worker 973*da0073e9SAndroid Build Coastguard Worker def forward(self, v): 974*da0073e9SAndroid Build Coastguard Worker for m in self.model: 975*da0073e9SAndroid Build Coastguard Worker v = m(v) 976*da0073e9SAndroid Build Coastguard Worker return v 977*da0073e9SAndroid Build Coastguard Worker 978*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Mod()) 979*da0073e9SAndroid Build Coastguard Worker assert m.graph # ensure jit was able to compile 980*da0073e9SAndroid Build Coastguard Worker 981*da0073e9SAndroid Build Coastguard Worker def test_disabled(self): 982*da0073e9SAndroid Build Coastguard Worker torch.jit._state.disable() 983*da0073e9SAndroid Build Coastguard Worker try: 984*da0073e9SAndroid Build Coastguard Worker def f(x, y): 985*da0073e9SAndroid Build Coastguard Worker return x + y 986*da0073e9SAndroid Build Coastguard Worker 987*da0073e9SAndroid Build Coastguard Worker self.assertIs(torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 
2))), f) 988*da0073e9SAndroid Build Coastguard Worker self.assertIs(torch.jit.script(f), f) 989*da0073e9SAndroid Build Coastguard Worker 990*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.jit.ScriptModule): 991*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 992*da0073e9SAndroid Build Coastguard Worker def method(self, x): 993*da0073e9SAndroid Build Coastguard Worker return x 994*da0073e9SAndroid Build Coastguard Worker 995*da0073e9SAndroid Build Coastguard Worker # XXX: Unfortunately ScriptModule won't simply become Module now, 996*da0073e9SAndroid Build Coastguard Worker # because that requires disabling the JIT at startup time, which 997*da0073e9SAndroid Build Coastguard Worker # we can't do in here. 998*da0073e9SAndroid Build Coastguard Worker # We need to or those two conditions to make it work with all versions of Python 999*da0073e9SAndroid Build Coastguard Worker self.assertTrue(inspect.ismethod(MyModule.method) or inspect.isfunction(MyModule.method)) 1000*da0073e9SAndroid Build Coastguard Worker finally: 1001*da0073e9SAndroid Build Coastguard Worker torch.jit._state.enable() 1002*da0073e9SAndroid Build Coastguard Worker 1003*da0073e9SAndroid Build Coastguard Worker def test_train_eval(self): 1004*da0073e9SAndroid Build Coastguard Worker class Sub(nn.Module): 1005*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 1006*da0073e9SAndroid Build Coastguard Worker if self.training: 1007*da0073e9SAndroid Build Coastguard Worker return input 1008*da0073e9SAndroid Build Coastguard Worker else: 1009*da0073e9SAndroid Build Coastguard Worker return -input 1010*da0073e9SAndroid Build Coastguard Worker 1011*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.jit.ScriptModule): 1012*da0073e9SAndroid Build Coastguard Worker def __init__(self, module): 1013*da0073e9SAndroid Build Coastguard Worker super().__init__() 1014*da0073e9SAndroid Build Coastguard Worker self.module = module 1015*da0073e9SAndroid Build 
Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 1017*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 1018*da0073e9SAndroid Build Coastguard Worker return self.module(input) + 1 1019*da0073e9SAndroid Build Coastguard Worker 1020*da0073e9SAndroid Build Coastguard Worker m = MyModule(Sub()) 1021*da0073e9SAndroid Build Coastguard Worker input = torch.rand(3, 4) 1022*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input + 1, m(input)) 1023*da0073e9SAndroid Build Coastguard Worker m.eval() 1024*da0073e9SAndroid Build Coastguard Worker self.assertEqual(-input + 1, m(input)) 1025*da0073e9SAndroid Build Coastguard Worker 1026*da0073e9SAndroid Build Coastguard Worker # test batchnorm and dropout train/eval 1027*da0073e9SAndroid Build Coastguard Worker input = torch.randn(6, 10) 1028*da0073e9SAndroid Build Coastguard Worker batchnorm = nn.BatchNorm1d(10) 1029*da0073e9SAndroid Build Coastguard Worker dropout = nn.Dropout(p=0.2) 1030*da0073e9SAndroid Build Coastguard Worker 1031*da0073e9SAndroid Build Coastguard Worker m_batchnorm = MyModule(batchnorm) 1032*da0073e9SAndroid Build Coastguard Worker self.assertEqual(batchnorm(input) + 1, m_batchnorm(input)) 1033*da0073e9SAndroid Build Coastguard Worker batchnorm.eval() 1034*da0073e9SAndroid Build Coastguard Worker m_batchnorm.eval() 1035*da0073e9SAndroid Build Coastguard Worker self.assertEqual(batchnorm(input) + 1, m_batchnorm(input)) 1036*da0073e9SAndroid Build Coastguard Worker 1037*da0073e9SAndroid Build Coastguard Worker m_dropout = MyModule(dropout) 1038*da0073e9SAndroid Build Coastguard Worker dropout.eval() 1039*da0073e9SAndroid Build Coastguard Worker m_dropout.eval() 1040*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dropout(input) + 1, m_dropout(input)) 1041*da0073e9SAndroid Build Coastguard Worker 1042*da0073e9SAndroid Build Coastguard Worker def test_nn_lp_pool2d(self): 1043*da0073e9SAndroid Build Coastguard Worker class 
Mod(torch.nn.Module): 1044*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1045*da0073e9SAndroid Build Coastguard Worker super().__init__() 1046*da0073e9SAndroid Build Coastguard Worker self.l = torch.nn.LPPool2d(2, 3) 1047*da0073e9SAndroid Build Coastguard Worker self.n = torch.nn.LPPool2d(2, (7, 1)) 1048*da0073e9SAndroid Build Coastguard Worker 1049*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1050*da0073e9SAndroid Build Coastguard Worker return (self.l(x), 1051*da0073e9SAndroid Build Coastguard Worker self.n(x), 1052*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.lp_pool2d(x, float(2), 3), 1053*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.lp_pool2d(x, 2, 3), 1054*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.lp_pool2d(x, float(2), (7, 1))) 1055*da0073e9SAndroid Build Coastguard Worker 1056*da0073e9SAndroid Build Coastguard Worker self.checkModule(Mod(), (torch.rand(1, 3, 7, 7),)) 1057*da0073e9SAndroid Build Coastguard Worker 1058*da0073e9SAndroid Build Coastguard Worker def test_nn_lp_pool1d(self): 1059*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 1060*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1061*da0073e9SAndroid Build Coastguard Worker super().__init__() 1062*da0073e9SAndroid Build Coastguard Worker self.l = torch.nn.LPPool1d(2, 3) 1063*da0073e9SAndroid Build Coastguard Worker self.n = torch.nn.LPPool1d(2, 7) 1064*da0073e9SAndroid Build Coastguard Worker 1065*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1066*da0073e9SAndroid Build Coastguard Worker return (self.l(x), 1067*da0073e9SAndroid Build Coastguard Worker self.n(x), 1068*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.lp_pool1d(x, float(2), 3), 1069*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.lp_pool1d(x, 2, 3), 1070*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.lp_pool1d(x, float(2), 7)) 
1071*da0073e9SAndroid Build Coastguard Worker 1072*da0073e9SAndroid Build Coastguard Worker self.checkModule(Mod(), (torch.rand(1, 3, 7),)) 1073*da0073e9SAndroid Build Coastguard Worker 1074*da0073e9SAndroid Build Coastguard Worker def test_nn_padding_functional(self): 1075*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 1076*da0073e9SAndroid Build Coastguard Worker def __init__(self, *pad): 1077*da0073e9SAndroid Build Coastguard Worker super().__init__() 1078*da0073e9SAndroid Build Coastguard Worker self.pad = pad 1079*da0073e9SAndroid Build Coastguard Worker 1080*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1081*da0073e9SAndroid Build Coastguard Worker return F.pad(x, self.pad, mode='constant', value=3.5) 1082*da0073e9SAndroid Build Coastguard Worker 1083*da0073e9SAndroid Build Coastguard Worker inputs = [ 1084*da0073e9SAndroid Build Coastguard Worker (Mod(1, 2), torch.randn(1, 3, 4)), # 1D 1085*da0073e9SAndroid Build Coastguard Worker (Mod(1, 2, 3, 4), torch.randn(1, 3, 4)), # 2D 1086*da0073e9SAndroid Build Coastguard Worker (Mod(1, 2, 3, 4, 5, 6), torch.randn(1, 3, 4)), # 3D 1087*da0073e9SAndroid Build Coastguard Worker ] 1088*da0073e9SAndroid Build Coastguard Worker 1089*da0073e9SAndroid Build Coastguard Worker for m, inp in inputs: 1090*da0073e9SAndroid Build Coastguard Worker self.checkModule(m, (inp,)) 1091*da0073e9SAndroid Build Coastguard Worker 1092*da0073e9SAndroid Build Coastguard Worker def test_nn_padding(self): 1093*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 1094*da0073e9SAndroid Build Coastguard Worker def __init__(self, padding): 1095*da0073e9SAndroid Build Coastguard Worker super().__init__() 1096*da0073e9SAndroid Build Coastguard Worker self.padding = padding 1097*da0073e9SAndroid Build Coastguard Worker 1098*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 1099*da0073e9SAndroid Build Coastguard Worker return self.padding(input) 1100*da0073e9SAndroid Build Coastguard Worker 
1101*da0073e9SAndroid Build Coastguard Worker inputs = [ 1102*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ConstantPad1d(2, 3.5)), torch.randn(1, 2, 4)), 1103*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ConstantPad2d(2, 3.5)), torch.randn(1, 2, 2)), 1104*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ConstantPad3d(3, 3.5)), torch.randn(16, 3, 10, 20, 30)), 1105*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ReflectionPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)), 1106*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ReflectionPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)), 1107*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ReflectionPad3d(3)), torch.randn(16, 3, 8, 32, 48)), 1108*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ReplicationPad1d(2)), torch.arange(8, dtype=torch.float).reshape(1, 2, 4)), 1109*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ReplicationPad2d(2)), torch.arange(9, dtype=torch.float).reshape(1, 1, 3, 3)), 1110*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ReplicationPad3d(3)), torch.randn(16, 3, 8, 32, 48)), 1111*da0073e9SAndroid Build Coastguard Worker (Mod(nn.ZeroPad2d(2)), torch.randn(1, 1, 3, 3)) 1112*da0073e9SAndroid Build Coastguard Worker ] 1113*da0073e9SAndroid Build Coastguard Worker 1114*da0073e9SAndroid Build Coastguard Worker for m, inp in inputs: 1115*da0073e9SAndroid Build Coastguard Worker self.checkModule(m, (inp,)) 1116*da0073e9SAndroid Build Coastguard Worker 1117*da0073e9SAndroid Build Coastguard Worker def test_script_autograd_grad(self): 1118*da0073e9SAndroid Build Coastguard Worker def test_simple_grad(x, y): 1119*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> List[Optional[Tensor]] 1120*da0073e9SAndroid Build Coastguard Worker z = x + 2 * y + x * y 1121*da0073e9SAndroid Build Coastguard Worker return torch.autograd.grad((z.sum(), ), (x, y)) 1122*da0073e9SAndroid Build Coastguard Worker 1123*da0073e9SAndroid Build Coastguard Worker def 
test_simple_grad_with_grad_outputs(x, y): 1124*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> List[Optional[Tensor]] 1125*da0073e9SAndroid Build Coastguard Worker z = x + 2 * y + x * y 1126*da0073e9SAndroid Build Coastguard Worker grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ]) 1127*da0073e9SAndroid Build Coastguard Worker return torch.autograd.grad((z, ), (x, y), grad_outputs) 1128*da0073e9SAndroid Build Coastguard Worker 1129*da0073e9SAndroid Build Coastguard Worker def test_one_output_not_requires_grad(x, y): 1130*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> List[Optional[Tensor]] 1131*da0073e9SAndroid Build Coastguard Worker z = 2 * y + y 1132*da0073e9SAndroid Build Coastguard Worker return torch.autograd.grad((z.sum(),), (x, y), allow_unused=True) 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Worker def test_retain_graph(x, y): 1135*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> None 1136*da0073e9SAndroid Build Coastguard Worker z = x + 2 * y + x * y 1137*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad((z.sum(), ), (x, y), retain_graph=True) 1138*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad((z.sum(), ), (x, y)) 1139*da0073e9SAndroid Build Coastguard Worker 1140*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2, requires_grad=True) 1141*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 2, requires_grad=True) 1142*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_simple_grad, (x, y), inputs_requires_grad=True) 1143*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_simple_grad_with_grad_outputs, (x, y), inputs_requires_grad=True) 1144*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_one_output_not_requires_grad, (x, y), inputs_requires_grad=True) 1145*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_retain_graph, 
(x, y), inputs_requires_grad=True) 1146*da0073e9SAndroid Build Coastguard Worker 1147*da0073e9SAndroid Build Coastguard Worker def test_script_backward(self): 1148*da0073e9SAndroid Build Coastguard Worker def checkBackwardScript(fn, inputs): 1149*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(fn) 1150*da0073e9SAndroid Build Coastguard Worker FileCheck().check("torch.autograd.backward").run(scripted_fn.code) 1151*da0073e9SAndroid Build Coastguard Worker recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs) 1152*da0073e9SAndroid Build Coastguard Worker 1153*da0073e9SAndroid Build Coastguard Worker fn(*inputs) 1154*da0073e9SAndroid Build Coastguard Worker scripted_fn(*recording_inputs) 1155*da0073e9SAndroid Build Coastguard Worker 1156*da0073e9SAndroid Build Coastguard Worker for inp1, inp2 in zip(inputs, recording_inputs): 1157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inp1.grad, inp2.grad) 1158*da0073e9SAndroid Build Coastguard Worker 1159*da0073e9SAndroid Build Coastguard Worker def test_tensor_backward(input): 1160*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> None 1161*da0073e9SAndroid Build Coastguard Worker output = torch.relu(input) 1162*da0073e9SAndroid Build Coastguard Worker output = output.softmax(0) 1163*da0073e9SAndroid Build Coastguard Worker sum_out = output.sum() 1164*da0073e9SAndroid Build Coastguard Worker sum_out.backward() 1165*da0073e9SAndroid Build Coastguard Worker 1166*da0073e9SAndroid Build Coastguard Worker def test_torch_autograd_backward(input): 1167*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> None 1168*da0073e9SAndroid Build Coastguard Worker output = torch.relu(input) 1169*da0073e9SAndroid Build Coastguard Worker output = output.softmax(0) 1170*da0073e9SAndroid Build Coastguard Worker torch.autograd.backward(output.sum()) 1171*da0073e9SAndroid Build Coastguard Worker 1172*da0073e9SAndroid Build Coastguard Worker def 
test_torch_autograd_backward_with_grad_tensors(input): 1173*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> None 1174*da0073e9SAndroid Build Coastguard Worker output = torch.relu(input) 1175*da0073e9SAndroid Build Coastguard Worker output = output.softmax(0) 1176*da0073e9SAndroid Build Coastguard Worker grad_outputs = torch.jit.annotate(List[Optional[torch.Tensor]], [torch.ones((2, 2)), ]) 1177*da0073e9SAndroid Build Coastguard Worker torch.autograd.backward((output,), grad_outputs) 1178*da0073e9SAndroid Build Coastguard Worker 1179*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(2, 2, requires_grad=True) 1180*da0073e9SAndroid Build Coastguard Worker checkBackwardScript(test_tensor_backward, (inp,)) 1181*da0073e9SAndroid Build Coastguard Worker checkBackwardScript(test_torch_autograd_backward, (inp,)) 1182*da0073e9SAndroid Build Coastguard Worker checkBackwardScript(test_torch_autograd_backward_with_grad_tensors, (inp,)) 1183*da0073e9SAndroid Build Coastguard Worker 1184*da0073e9SAndroid Build Coastguard Worker def test_script_backward_twice(self): 1185*da0073e9SAndroid Build Coastguard Worker def checkBackwardTwiceScript(fn, inputs, retain_graph_=False): 1186*da0073e9SAndroid Build Coastguard Worker class jit_profiling_executor_false: 1187*da0073e9SAndroid Build Coastguard Worker def __enter__(self): 1188*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_profiling_executor(False) 1189*da0073e9SAndroid Build Coastguard Worker 1190*da0073e9SAndroid Build Coastguard Worker def __exit__(self, *args): 1191*da0073e9SAndroid Build Coastguard Worker torch._C._jit_set_profiling_executor(GRAPH_EXECUTOR != ProfilingMode.LEGACY) 1192*da0073e9SAndroid Build Coastguard Worker 1193*da0073e9SAndroid Build Coastguard Worker with jit_profiling_executor_false(), torch.jit.optimized_execution(True): 1194*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(fn, inputs) 1195*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check("prim::DifferentiableGraph").run(scripted_fn.graph_for(*inputs)) 1196*da0073e9SAndroid Build Coastguard Worker 1197*da0073e9SAndroid Build Coastguard Worker result = scripted_fn(*inputs) 1198*da0073e9SAndroid Build Coastguard Worker result.sum().backward(retain_graph=retain_graph_) 1199*da0073e9SAndroid Build Coastguard Worker if not retain_graph_: 1200*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True', 1201*da0073e9SAndroid Build Coastguard Worker lambda: result.sum().backward()) 1202*da0073e9SAndroid Build Coastguard Worker else: 1203*da0073e9SAndroid Build Coastguard Worker result.sum().backward() 1204*da0073e9SAndroid Build Coastguard Worker 1205*da0073e9SAndroid Build Coastguard Worker def test_script_backward_twice_with_saved_values(input1, input2): 1206*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> Tensor 1207*da0073e9SAndroid Build Coastguard Worker tmp1 = torch.mul(input1, input2) 1208*da0073e9SAndroid Build Coastguard Worker tmp2 = torch.abs(tmp1) 1209*da0073e9SAndroid Build Coastguard Worker if torch.equal(input1, input2): 1210*da0073e9SAndroid Build Coastguard Worker tmp2 = torch.acos(tmp2) 1211*da0073e9SAndroid Build Coastguard Worker else: 1212*da0073e9SAndroid Build Coastguard Worker tmp2 = torch.atan(tmp2) 1213*da0073e9SAndroid Build Coastguard Worker result = torch.add(tmp2, input2) 1214*da0073e9SAndroid Build Coastguard Worker return result 1215*da0073e9SAndroid Build Coastguard Worker 1216*da0073e9SAndroid Build Coastguard Worker inp1 = torch.randn(2, 2, requires_grad=True) 1217*da0073e9SAndroid Build Coastguard Worker inp2 = torch.randn(2, 2, requires_grad=True) 1218*da0073e9SAndroid Build Coastguard Worker checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), False) 1219*da0073e9SAndroid Build Coastguard Worker checkBackwardTwiceScript(test_script_backward_twice_with_saved_values, (inp1, inp2), True) 
1220*da0073e9SAndroid Build Coastguard Worker 1221*da0073e9SAndroid Build Coastguard Worker def test_diff_subgraph_clones_constants(self): 1222*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1223*da0073e9SAndroid Build Coastguard Worker def f(x, y): 1224*da0073e9SAndroid Build Coastguard Worker return x + x + y + x + y + x + y + x + y + x 1225*da0073e9SAndroid Build Coastguard Worker 1226*da0073e9SAndroid Build Coastguard Worker def count_constants(graph): 1227*da0073e9SAndroid Build Coastguard Worker return sum(node.kind() == 'prim::Constant' for node in graph.nodes()) 1228*da0073e9SAndroid Build Coastguard Worker 1229*da0073e9SAndroid Build Coastguard Worker graph = f.graph.copy() 1230*da0073e9SAndroid Build Coastguard Worker self.run_pass('cse', graph) 1231*da0073e9SAndroid Build Coastguard Worker self.run_pass('create_autodiff_subgraphs', graph) 1232*da0073e9SAndroid Build Coastguard Worker nodes = list(graph.nodes()) 1233*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count_constants(graph), 1) 1234*da0073e9SAndroid Build Coastguard Worker self.assertEqual(count_constants(nodes[1].g('Subgraph')), 1) 1235*da0073e9SAndroid Build Coastguard Worker 1236*da0073e9SAndroid Build Coastguard Worker # TODO: adapt this test to check that GraphExecutor treats them differently 1237*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Need to be adjusted to Graph Executor") 1238*da0073e9SAndroid Build Coastguard Worker def test_arg_configurations(self): 1239*da0073e9SAndroid Build Coastguard Worker """Different arg configurations should trigger different traces""" 1240*da0073e9SAndroid Build Coastguard Worker x = Variable(torch.FloatTensor(4, 4).uniform_()) 1241*da0073e9SAndroid Build Coastguard Worker x_double = Variable(x.data.double()) 1242*da0073e9SAndroid Build Coastguard Worker x_grad = Variable(x.data.clone(), requires_grad=True) 1243*da0073e9SAndroid Build Coastguard Worker y = Variable(torch.randn(4)) 1244*da0073e9SAndroid Build 
Coastguard Worker 1245*da0073e9SAndroid Build Coastguard Worker configurations = [ 1246*da0073e9SAndroid Build Coastguard Worker (x,), 1247*da0073e9SAndroid Build Coastguard Worker (x_double,), 1248*da0073e9SAndroid Build Coastguard Worker (x_grad,), 1249*da0073e9SAndroid Build Coastguard Worker (y,), 1250*da0073e9SAndroid Build Coastguard Worker ([x, x],), 1251*da0073e9SAndroid Build Coastguard Worker ([x, y],), 1252*da0073e9SAndroid Build Coastguard Worker ] 1253*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 1254*da0073e9SAndroid Build Coastguard Worker x_cuda = Variable(x.data.cuda()) 1255*da0073e9SAndroid Build Coastguard Worker configurations += [ 1256*da0073e9SAndroid Build Coastguard Worker (x_cuda,), 1257*da0073e9SAndroid Build Coastguard Worker ([x, x_cuda],), 1258*da0073e9SAndroid Build Coastguard Worker ([x_cuda, x],), 1259*da0073e9SAndroid Build Coastguard Worker ([[x_cuda, x]],), 1260*da0073e9SAndroid Build Coastguard Worker ] 1261*da0073e9SAndroid Build Coastguard Worker if torch.cuda.device_count() > 1: 1262*da0073e9SAndroid Build Coastguard Worker x_cuda_1 = Variable(x.data.cuda(1)) 1263*da0073e9SAndroid Build Coastguard Worker configurations += [ 1264*da0073e9SAndroid Build Coastguard Worker (x_cuda_1,), 1265*da0073e9SAndroid Build Coastguard Worker ([x_cuda, x_cuda_1],), 1266*da0073e9SAndroid Build Coastguard Worker ] 1267*da0073e9SAndroid Build Coastguard Worker 1268*da0073e9SAndroid Build Coastguard Worker @torch.jit.compile(nderivs=0) 1269*da0073e9SAndroid Build Coastguard Worker def fn(*args): 1270*da0073e9SAndroid Build Coastguard Worker in_vars, _ = torch._C._jit_flatten(args) 1271*da0073e9SAndroid Build Coastguard Worker return in_vars[0] + 1 1272*da0073e9SAndroid Build Coastguard Worker 1273*da0073e9SAndroid Build Coastguard Worker for i, config in enumerate(configurations): 1274*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fn.has_trace_for(*config)) 1275*da0073e9SAndroid Build Coastguard Worker 
fn(*config) 1276*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fn.has_trace_for(*config)) 1277*da0073e9SAndroid Build Coastguard Worker for unk_config in configurations[i + 1:]: 1278*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fn.has_trace_for(*unk_config)) 1279*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn.hits, 0) 1280*da0073e9SAndroid Build Coastguard Worker 1281*da0073e9SAndroid Build Coastguard Worker def test_torch_sum(self): 1282*da0073e9SAndroid Build Coastguard Worker def fn(x): 1283*da0073e9SAndroid Build Coastguard Worker return torch.sum(x) 1284*da0073e9SAndroid Build Coastguard Worker 1285*da0073e9SAndroid Build Coastguard Worker def fn1(x, dim: int): 1286*da0073e9SAndroid Build Coastguard Worker return torch.sum(x, dim) 1287*da0073e9SAndroid Build Coastguard Worker 1288*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 4) 1289*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (x, )) 1290*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn1, (x, 1, )) 1291*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn1, (x, 0, )) 1292*da0073e9SAndroid Build Coastguard Worker 1293*da0073e9SAndroid Build Coastguard Worker def test_cse(self): 1294*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.4, 0.3], requires_grad=True) 1295*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([0.7, 0.5], requires_grad=True) 1296*da0073e9SAndroid Build Coastguard Worker 1297*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 1298*da0073e9SAndroid Build Coastguard Worker w = (x + y) * (x + y) * (x + y) 1299*da0073e9SAndroid Build Coastguard Worker t = torch.tanh(w) + torch.tanh(w) 1300*da0073e9SAndroid Build Coastguard Worker z = (x + y) * (x + y) * (x + y) + t 1301*da0073e9SAndroid Build Coastguard Worker return z 1302*da0073e9SAndroid Build Coastguard Worker 1303*da0073e9SAndroid Build Coastguard Worker g, _ = torch.jit._get_trace_graph(fn, (x, y)) 1304*da0073e9SAndroid Build 
Coastguard Worker self.run_pass('cse', g) 1305*da0073e9SAndroid Build Coastguard Worker do_exactly = True 1306*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("add", 1).check_count("mul", 2, do_exactly) \ 1307*da0073e9SAndroid Build Coastguard Worker .check_count("tanh", 1, do_exactly).check_count("add", 2, do_exactly).check_next("return") \ 1308*da0073e9SAndroid Build Coastguard Worker .run(str(g)) 1309*da0073e9SAndroid Build Coastguard Worker 1310*da0073e9SAndroid Build Coastguard Worker self.assertExportImport(g, (x, y)) 1311*da0073e9SAndroid Build Coastguard Worker 1312*da0073e9SAndroid Build Coastguard Worker def test_cse_not_introduce_aliasing(self): 1313*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1314*da0073e9SAndroid Build Coastguard Worker def tensor_alias_outputs(x): 1315*da0073e9SAndroid Build Coastguard Worker return x + x, x + x 1316*da0073e9SAndroid Build Coastguard Worker 1317*da0073e9SAndroid Build Coastguard Worker self.run_pass('cse', tensor_alias_outputs.graph) 1318*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::add", 2).run(tensor_alias_outputs.graph) 1319*da0073e9SAndroid Build Coastguard Worker 1320*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1321*da0073e9SAndroid Build Coastguard Worker def ints_alias_outputs(x): 1322*da0073e9SAndroid Build Coastguard Worker # type: (int) -> Tuple[int, int] 1323*da0073e9SAndroid Build Coastguard Worker return x + x, x + x 1324*da0073e9SAndroid Build Coastguard Worker 1325*da0073e9SAndroid Build Coastguard Worker # non-aliasing types can be CSEd 1326*da0073e9SAndroid Build Coastguard Worker self.run_pass('cse', ints_alias_outputs.graph) 1327*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::add", 1, exactly=True).run(ints_alias_outputs.graph) 1328*da0073e9SAndroid Build Coastguard Worker 1329*da0073e9SAndroid Build Coastguard Worker def test_recursive_cse(self): 1330*da0073e9SAndroid Build Coastguard Worker input_str 
= """ 1331*da0073e9SAndroid Build Coastguard Workergraph(%x : Tensor, 1332*da0073e9SAndroid Build Coastguard Worker %y : Tensor, 1333*da0073e9SAndroid Build Coastguard Worker %20 : int): 1334*da0073e9SAndroid Build Coastguard Worker %2 : int = prim::Constant[value=1]() 1335*da0073e9SAndroid Build Coastguard Worker %3 : Tensor = aten::add(%x, %y, %2) 1336*da0073e9SAndroid Build Coastguard Worker %4 : int = aten::add(%2, %20) 1337*da0073e9SAndroid Build Coastguard Worker %5 : bool = aten::Bool(%4) 1338*da0073e9SAndroid Build Coastguard Worker %z : int = prim::If(%5) 1339*da0073e9SAndroid Build Coastguard Worker # CHECK: block 1340*da0073e9SAndroid Build Coastguard Worker block0(): 1341*da0073e9SAndroid Build Coastguard Worker # CHECK-NOT: aten::add 1342*da0073e9SAndroid Build Coastguard Worker %z.1 : int = aten::add(%2, %20) 1343*da0073e9SAndroid Build Coastguard Worker -> (%z.1) 1344*da0073e9SAndroid Build Coastguard Worker block1(): 1345*da0073e9SAndroid Build Coastguard Worker -> (%2) 1346*da0073e9SAndroid Build Coastguard Worker return (%z) 1347*da0073e9SAndroid Build Coastguard Worker""" 1348*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(input_str) 1349*da0073e9SAndroid Build Coastguard Worker self.run_pass('cse', graph) 1350*da0073e9SAndroid Build Coastguard Worker FileCheck().run(input_str, graph) 1351*da0073e9SAndroid Build Coastguard Worker 1352*da0073e9SAndroid Build Coastguard Worker def test_pattern_based_rewrite(self): 1353*da0073e9SAndroid Build Coastguard Worker # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) --> 1354*da0073e9SAndroid Build Coastguard Worker # --> mulmul(mulmul(x,y,z), x, y) 1355*da0073e9SAndroid Build Coastguard Worker input_str = """ 1356*da0073e9SAndroid Build Coastguard Workergraph(%x, %y, %z): 1357*da0073e9SAndroid Build Coastguard Worker # CHECK-NOT: aten::mul 1358*da0073e9SAndroid Build Coastguard Worker # CHECK: my::fused_mulmul 1359*da0073e9SAndroid Build Coastguard Worker %t = aten::mul(%x, %y) 
1360*da0073e9SAndroid Build Coastguard Worker %p = aten::mul(%t, %z) 1361*da0073e9SAndroid Build Coastguard Worker # CHECK: my::fused_mulmul 1362*da0073e9SAndroid Build Coastguard Worker %u = aten::mul(%p, %x) 1363*da0073e9SAndroid Build Coastguard Worker %o = aten::mul(%u, %y) 1364*da0073e9SAndroid Build Coastguard Worker return (%o)""" 1365*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(input_str) 1366*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1367*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c): 1368*da0073e9SAndroid Build Coastguard Worker %q = aten::mul(%a, %b) 1369*da0073e9SAndroid Build Coastguard Worker %r = aten::mul(%q, %c) 1370*da0073e9SAndroid Build Coastguard Worker return (%r)""", """ 1371*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c): 1372*da0073e9SAndroid Build Coastguard Worker %r = my::fused_mulmul(%a, %b, %c) 1373*da0073e9SAndroid Build Coastguard Worker return (%r)""", graph) 1374*da0073e9SAndroid Build Coastguard Worker FileCheck().run(input_str, graph) 1375*da0073e9SAndroid Build Coastguard Worker 1376*da0073e9SAndroid Build Coastguard Worker # Check that overlapping matches are handled correctly 1377*da0073e9SAndroid Build Coastguard Worker # mul(mul(mul(x,y),z),x) --> mul(mulmul(x,y,z), x) 1378*da0073e9SAndroid Build Coastguard Worker input_str = """ 1379*da0073e9SAndroid Build Coastguard Workergraph(%x, %y, %z): 1380*da0073e9SAndroid Build Coastguard Worker # CHECK-NOT: aten::mul 1381*da0073e9SAndroid Build Coastguard Worker # CHECK: my::fused_mulmul 1382*da0073e9SAndroid Build Coastguard Worker %t = aten::mul(%x, %y) 1383*da0073e9SAndroid Build Coastguard Worker %p = aten::mul(%t, %z) 1384*da0073e9SAndroid Build Coastguard Worker # CHECK-NEXT: aten::mul 1385*da0073e9SAndroid Build Coastguard Worker %u = aten::mul(%p, %x) 1386*da0073e9SAndroid Build Coastguard Worker return (%u)""" 1387*da0073e9SAndroid Build Coastguard Worker graph = 
parse_ir(input_str) 1388*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1389*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c): 1390*da0073e9SAndroid Build Coastguard Worker %q = aten::mul(%a, %b) 1391*da0073e9SAndroid Build Coastguard Worker %r = aten::mul(%q, %c) 1392*da0073e9SAndroid Build Coastguard Worker return (%r)""", """ 1393*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c): 1394*da0073e9SAndroid Build Coastguard Worker %r = my::fused_mulmul(%a, %b, %c) 1395*da0073e9SAndroid Build Coastguard Worker return (%r)""", graph) 1396*da0073e9SAndroid Build Coastguard Worker FileCheck().run(input_str, graph) 1397*da0073e9SAndroid Build Coastguard Worker 1398*da0073e9SAndroid Build Coastguard Worker # Check add(mul(x,y),z) --> muladd(x,y,z) replacement 1399*da0073e9SAndroid Build Coastguard Worker input_str = """ 1400*da0073e9SAndroid Build Coastguard Workergraph(%x, %y, %z): 1401*da0073e9SAndroid Build Coastguard Worker # CHECK-NOT: aten::mul 1402*da0073e9SAndroid Build Coastguard Worker # CHECK-NOT: aten::add 1403*da0073e9SAndroid Build Coastguard Worker %c = prim::Const[value=1]() 1404*da0073e9SAndroid Build Coastguard Worker %t = aten::mul(%x, %y) 1405*da0073e9SAndroid Build Coastguard Worker %p = aten::add(%t, %z, %c) 1406*da0073e9SAndroid Build Coastguard Worker # CHECK: my::muladd 1407*da0073e9SAndroid Build Coastguard Worker # CHECK-NEXT: return 1408*da0073e9SAndroid Build Coastguard Worker return (%p)""" 1409*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(input_str) 1410*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1411*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c, %d): 1412*da0073e9SAndroid Build Coastguard Worker %q = aten::mul(%a, %b) 1413*da0073e9SAndroid Build Coastguard Worker %r = aten::add(%q, %c, %d) 1414*da0073e9SAndroid Build Coastguard Worker return (%r)""", """ 1415*da0073e9SAndroid Build Coastguard 
Workergraph(%a, %b, %c, %d): 1416*da0073e9SAndroid Build Coastguard Worker %r = my::muladd(%a, %b, %c, %d) 1417*da0073e9SAndroid Build Coastguard Worker return (%r)""", graph) 1418*da0073e9SAndroid Build Coastguard Worker FileCheck().run(input_str, graph) 1419*da0073e9SAndroid Build Coastguard Worker 1420*da0073e9SAndroid Build Coastguard Worker # Check add(mul(x,y),z) --> sub(add(x,y),z) replacement 1421*da0073e9SAndroid Build Coastguard Worker input_str = """ 1422*da0073e9SAndroid Build Coastguard Workergraph(%x, %y, %z): 1423*da0073e9SAndroid Build Coastguard Worker # CHECK-NOT: aten::mul 1424*da0073e9SAndroid Build Coastguard Worker %c = prim::Const[value=1]() 1425*da0073e9SAndroid Build Coastguard Worker # CHECK: aten::add 1426*da0073e9SAndroid Build Coastguard Worker %t = aten::mul(%x, %y) 1427*da0073e9SAndroid Build Coastguard Worker # CHECK-NEXT: aten::sub 1428*da0073e9SAndroid Build Coastguard Worker %p = aten::add(%t, %z, %c) 1429*da0073e9SAndroid Build Coastguard Worker # CHECK-NOT: aten::add 1430*da0073e9SAndroid Build Coastguard Worker # CHECK-NEXT: return 1431*da0073e9SAndroid Build Coastguard Worker return (%p)""" 1432*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(input_str) 1433*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1434*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c, %d): 1435*da0073e9SAndroid Build Coastguard Worker %q = aten::mul(%a, %b) 1436*da0073e9SAndroid Build Coastguard Worker %r = aten::add(%q, %c, %d) 1437*da0073e9SAndroid Build Coastguard Worker return (%r)""", """ 1438*da0073e9SAndroid Build Coastguard Workergraph(%a, %b, %c, %d): 1439*da0073e9SAndroid Build Coastguard Worker %q = aten::add(%a, %b, %d) 1440*da0073e9SAndroid Build Coastguard Worker %r = aten::sub(%q, %c, %d) 1441*da0073e9SAndroid Build Coastguard Worker return (%r)""", graph) 1442*da0073e9SAndroid Build Coastguard Worker FileCheck().run(input_str, graph) 1443*da0073e9SAndroid Build 
Coastguard Worker 1444*da0073e9SAndroid Build Coastguard Worker # Check mul(x,y) --> x replacement 1445*da0073e9SAndroid Build Coastguard Worker input_str = """ 1446*da0073e9SAndroid Build Coastguard Workergraph(%x, %y, %z): 1447*da0073e9SAndroid Build Coastguard Worker %c = prim::Const[value=1]() 1448*da0073e9SAndroid Build Coastguard Worker # CHECK-NOT: aten::mul 1449*da0073e9SAndroid Build Coastguard Worker %t = aten::mul(%x, %y) 1450*da0073e9SAndroid Build Coastguard Worker # CHECK: aten::add(%x, %z 1451*da0073e9SAndroid Build Coastguard Worker %p = aten::add(%t, %z, %c) 1452*da0073e9SAndroid Build Coastguard Worker # CHECK-NEXT: return 1453*da0073e9SAndroid Build Coastguard Worker return (%p)""" 1454*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(input_str) 1455*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1456*da0073e9SAndroid Build Coastguard Workergraph(%Pa, %Pb): 1457*da0073e9SAndroid Build Coastguard Worker %Pq = aten::mul(%Pa, %Pb) 1458*da0073e9SAndroid Build Coastguard Worker return (%Pq)""", """ 1459*da0073e9SAndroid Build Coastguard Workergraph(%Ra, %Rb): 1460*da0073e9SAndroid Build Coastguard Worker return (%Ra)""", graph) 1461*da0073e9SAndroid Build Coastguard Worker FileCheck().run(input_str, graph) 1462*da0073e9SAndroid Build Coastguard Worker 1463*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 1464*da0073e9SAndroid Build Coastguard Worker def test_pattern_based_module_rewrite(self): 1465*da0073e9SAndroid Build Coastguard Worker # Check match::module behavior 1466*da0073e9SAndroid Build Coastguard Worker class Test(torch.nn.Module): 1467*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1468*da0073e9SAndroid Build Coastguard Worker super().__init__() 1469*da0073e9SAndroid Build Coastguard Worker self.conv = torch.nn.Conv2d(1, 20, 5, 1) 1470*da0073e9SAndroid Build Coastguard Worker self.bn = torch.nn.BatchNorm2d(num_features=20) 
1471*da0073e9SAndroid Build Coastguard Worker 1472*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 1473*da0073e9SAndroid Build Coastguard Worker x = self.conv(x) 1474*da0073e9SAndroid Build Coastguard Worker x = self.bn(x) 1475*da0073e9SAndroid Build Coastguard Worker return x 1476*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Test()) 1477*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph(""" 1478*da0073e9SAndroid Build Coastguard Worker graph(%self, %x): 1479*da0073e9SAndroid Build Coastguard Worker %conv = match::module[name="Conv2d"](%self) 1480*da0073e9SAndroid Build Coastguard Worker %y = prim::CallMethod[name="forward"](%conv, %x) 1481*da0073e9SAndroid Build Coastguard Worker %bn = match::module[name="BatchNorm2d"](%self) 1482*da0073e9SAndroid Build Coastguard Worker %z = prim::CallMethod[name="forward"](%bn, %y) 1483*da0073e9SAndroid Build Coastguard Worker return (%z)""", """ 1484*da0073e9SAndroid Build Coastguard Worker graph(%self, %x): 1485*da0073e9SAndroid Build Coastguard Worker %z = my::matched_conv_bn(%self, %x) 1486*da0073e9SAndroid Build Coastguard Worker return (%z)""", m._c._get_method("forward").graph) 1487*da0073e9SAndroid Build Coastguard Worker 1488*da0073e9SAndroid Build Coastguard Worker FileCheck().check("my::matched_conv_bn").run(m._c._get_method("forward").graph) 1489*da0073e9SAndroid Build Coastguard Worker 1490*da0073e9SAndroid Build Coastguard Worker def test_pattern_based_rewrite_with_source_range_preserved(self): 1491*da0073e9SAndroid Build Coastguard Worker class TestModule1(torch.nn.Module): 1492*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y, z, w): 1493*da0073e9SAndroid Build Coastguard Worker x = x + y 1494*da0073e9SAndroid Build Coastguard Worker x = x * z 1495*da0073e9SAndroid Build Coastguard Worker return w - x 1496*da0073e9SAndroid Build Coastguard Worker 1497*da0073e9SAndroid Build Coastguard Worker input_pattern = """ 
1498*da0073e9SAndroid Build Coastguard Worker graph(%x, %y, %z, %const): 1499*da0073e9SAndroid Build Coastguard Worker %t = aten::add(%x, %y, %const) 1500*da0073e9SAndroid Build Coastguard Worker %o = aten::mul(%t, %z) 1501*da0073e9SAndroid Build Coastguard Worker return (%o)""" 1502*da0073e9SAndroid Build Coastguard Worker replacement_pattern = """ 1503*da0073e9SAndroid Build Coastguard Worker graph(%x, %y, %z, %const): 1504*da0073e9SAndroid Build Coastguard Worker %o = my::add_mul(%x, %y, %z, %const) 1505*da0073e9SAndroid Build Coastguard Worker return (%o)""" 1506*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.script(TestModule1()) 1507*da0073e9SAndroid Build Coastguard Worker graph = scripted_model.graph 1508*da0073e9SAndroid Build Coastguard Worker value_mappings = [("o", "t")] 1509*da0073e9SAndroid Build Coastguard Worker for node in graph.nodes(): 1510*da0073e9SAndroid Build Coastguard Worker if node.kind() == "aten::add": 1511*da0073e9SAndroid Build Coastguard Worker source_range_1 = node.sourceRange() 1512*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph( 1513*da0073e9SAndroid Build Coastguard Worker input_pattern, replacement_pattern, scripted_model.graph, value_name_pairs=value_mappings) 1514*da0073e9SAndroid Build Coastguard Worker graph = scripted_model.graph 1515*da0073e9SAndroid Build Coastguard Worker for node in graph.nodes(): 1516*da0073e9SAndroid Build Coastguard Worker if node.kind() == "my::add_mul": 1517*da0073e9SAndroid Build Coastguard Worker source_range_2 = node.sourceRange() 1518*da0073e9SAndroid Build Coastguard Worker self.assertTrue(source_range_1 == source_range_2) 1519*da0073e9SAndroid Build Coastguard Worker 1520*da0073e9SAndroid Build Coastguard Worker class TestModule2(torch.nn.Module): 1521*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y, z, w): 1522*da0073e9SAndroid Build Coastguard Worker x = x + y 1523*da0073e9SAndroid Build Coastguard Worker x 
= x + z 1524*da0073e9SAndroid Build Coastguard Worker x = x * z 1525*da0073e9SAndroid Build Coastguard Worker x = x * w 1526*da0073e9SAndroid Build Coastguard Worker return x - 2 1527*da0073e9SAndroid Build Coastguard Worker 1528*da0073e9SAndroid Build Coastguard Worker # Check source range preservation for two node transforms add -> my_add 1529*da0073e9SAndroid Build Coastguard Worker input_pattern = """ 1530*da0073e9SAndroid Build Coastguard Worker graph(%x, %y, %const): 1531*da0073e9SAndroid Build Coastguard Worker %o = aten::add(%x, %y, %const) 1532*da0073e9SAndroid Build Coastguard Worker return (%o)""" 1533*da0073e9SAndroid Build Coastguard Worker replacement_pattern = """ 1534*da0073e9SAndroid Build Coastguard Worker graph(%x, %y, %const): 1535*da0073e9SAndroid Build Coastguard Worker %o = my::add(%x, %y, %const) 1536*da0073e9SAndroid Build Coastguard Worker return (%o)""" 1537*da0073e9SAndroid Build Coastguard Worker scripted_model = copy.deepcopy(torch.jit.script(TestModule2())) 1538*da0073e9SAndroid Build Coastguard Worker graph_copy = scripted_model.graph.copy() 1539*da0073e9SAndroid Build Coastguard Worker value_mappings = [("o", "o")] 1540*da0073e9SAndroid Build Coastguard Worker source_range_add_1 = None 1541*da0073e9SAndroid Build Coastguard Worker for node in graph_copy.nodes(): 1542*da0073e9SAndroid Build Coastguard Worker if source_range_add_1 is None and node.kind() == "aten::add": 1543*da0073e9SAndroid Build Coastguard Worker source_range_add_1 = node.sourceRange() 1544*da0073e9SAndroid Build Coastguard Worker if source_range_add_1 is not None and node.kind() == "aten::add": 1545*da0073e9SAndroid Build Coastguard Worker source_range_add_2 = node.sourceRange() 1546*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph( 1547*da0073e9SAndroid Build Coastguard Worker input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) 1548*da0073e9SAndroid Build Coastguard Worker 
source_range_my_add_1 = None 1549*da0073e9SAndroid Build Coastguard Worker for node in graph_copy.nodes(): 1550*da0073e9SAndroid Build Coastguard Worker if source_range_my_add_1 is None and node.kind() == "my::add": 1551*da0073e9SAndroid Build Coastguard Worker source_range_my_add_1 = node.sourceRange() 1552*da0073e9SAndroid Build Coastguard Worker if source_range_my_add_1 is not None and node.kind() == "my::add": 1553*da0073e9SAndroid Build Coastguard Worker source_range_my_add_2 = node.sourceRange() 1554*da0073e9SAndroid Build Coastguard Worker self.assertTrue(source_range_add_1 == source_range_my_add_1) 1555*da0073e9SAndroid Build Coastguard Worker self.assertTrue(source_range_add_2 == source_range_my_add_2) 1556*da0073e9SAndroid Build Coastguard Worker 1557*da0073e9SAndroid Build Coastguard Worker # Check source range preservation for add-add -> double_add transform 1558*da0073e9SAndroid Build Coastguard Worker # fuse nodes 1559*da0073e9SAndroid Build Coastguard Worker input_pattern = """ 1560*da0073e9SAndroid Build Coastguard Worker graph(%x, %y, %z, %const): 1561*da0073e9SAndroid Build Coastguard Worker %t = aten::add(%x, %y, %const) 1562*da0073e9SAndroid Build Coastguard Worker %o = aten::add(%t, %z, %const) 1563*da0073e9SAndroid Build Coastguard Worker return (%o)""" 1564*da0073e9SAndroid Build Coastguard Worker replacement_pattern = """ 1565*da0073e9SAndroid Build Coastguard Worker graph(%x, %y, %z, %const): 1566*da0073e9SAndroid Build Coastguard Worker %o = my::double_add(%x, %y, %z, %const) 1567*da0073e9SAndroid Build Coastguard Worker return (%o)""" 1568*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.script(TestModule2()) 1569*da0073e9SAndroid Build Coastguard Worker graph_copy = scripted_model.graph.copy() 1570*da0073e9SAndroid Build Coastguard Worker value_mappings = [("o", "t")] 1571*da0073e9SAndroid Build Coastguard Worker source_range_1 = None 1572*da0073e9SAndroid Build Coastguard Worker source_range_2 = None 
1573*da0073e9SAndroid Build Coastguard Worker for node in graph_copy.nodes(): 1574*da0073e9SAndroid Build Coastguard Worker if node.kind() == "aten::add": 1575*da0073e9SAndroid Build Coastguard Worker source_range_1 = node.sourceRange() 1576*da0073e9SAndroid Build Coastguard Worker break 1577*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph( 1578*da0073e9SAndroid Build Coastguard Worker input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) 1579*da0073e9SAndroid Build Coastguard Worker for node in graph_copy.nodes(): 1580*da0073e9SAndroid Build Coastguard Worker if node.kind() == "my::double_add": 1581*da0073e9SAndroid Build Coastguard Worker source_range_2 = node.sourceRange() 1582*da0073e9SAndroid Build Coastguard Worker self.assertTrue(source_range_1 == source_range_2) 1583*da0073e9SAndroid Build Coastguard Worker 1584*da0073e9SAndroid Build Coastguard Worker # Check source range preservation for mul -> add + add transform 1585*da0073e9SAndroid Build Coastguard Worker # split node 1586*da0073e9SAndroid Build Coastguard Worker input_pattern = """ 1587*da0073e9SAndroid Build Coastguard Worker graph(%x, %y): 1588*da0073e9SAndroid Build Coastguard Worker %t = aten::mul(%x, %y) 1589*da0073e9SAndroid Build Coastguard Worker return (%t)""" 1590*da0073e9SAndroid Build Coastguard Worker replacement_pattern = """ 1591*da0073e9SAndroid Build Coastguard Worker graph(%x, %y): 1592*da0073e9SAndroid Build Coastguard Worker %t = my::add(%x, %y) 1593*da0073e9SAndroid Build Coastguard Worker %o = my::add(%t, %y) 1594*da0073e9SAndroid Build Coastguard Worker return (%o)""" 1595*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.script(TestModule2()) 1596*da0073e9SAndroid Build Coastguard Worker graph_copy = scripted_model.graph.copy() 1597*da0073e9SAndroid Build Coastguard Worker value_mappings = [("t", "t"), ("o", "t")] 1598*da0073e9SAndroid Build Coastguard Worker source_range_mul_1 = None 
1599*da0073e9SAndroid Build Coastguard Worker for node in graph_copy.nodes(): 1600*da0073e9SAndroid Build Coastguard Worker if source_range_mul_1 is None and node.kind() == "aten::mul": 1601*da0073e9SAndroid Build Coastguard Worker source_range_mul_1 = node.sourceRange() 1602*da0073e9SAndroid Build Coastguard Worker if source_range_mul_1 is not None and node.kind() == "aten::mul": 1603*da0073e9SAndroid Build Coastguard Worker source_range_mul_2 = node.sourceRange() 1604*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph( 1605*da0073e9SAndroid Build Coastguard Worker input_pattern, replacement_pattern, graph_copy, value_name_pairs=value_mappings) 1606*da0073e9SAndroid Build Coastguard Worker source_range_add_1 = None 1607*da0073e9SAndroid Build Coastguard Worker for node in graph_copy.nodes(): 1608*da0073e9SAndroid Build Coastguard Worker if source_range_add_1 is None and node.kind() == "my::add": 1609*da0073e9SAndroid Build Coastguard Worker source_range_add_1 = node.sourceRange() 1610*da0073e9SAndroid Build Coastguard Worker if source_range_add_1 is not None and node.kind() == "my::add": 1611*da0073e9SAndroid Build Coastguard Worker source_range_add_2 = node.sourceRange() 1612*da0073e9SAndroid Build Coastguard Worker self.assertTrue(source_range_mul_1 == source_range_add_1) 1613*da0073e9SAndroid Build Coastguard Worker self.assertTrue(source_range_mul_2 == source_range_add_2) 1614*da0073e9SAndroid Build Coastguard Worker 1615*da0073e9SAndroid Build Coastguard Worker # Check lack of source range preservation for mul-mul-> double_mul transform 1616*da0073e9SAndroid Build Coastguard Worker input_pattern = """ 1617*da0073e9SAndroid Build Coastguard Worker graph(%x, %y, %z): 1618*da0073e9SAndroid Build Coastguard Worker %t = aten::mul(%x, %y) 1619*da0073e9SAndroid Build Coastguard Worker %o = aten::mul(%t, %z) 1620*da0073e9SAndroid Build Coastguard Worker return (%o)""" 1621*da0073e9SAndroid Build Coastguard Worker 
replacement_pattern = """ 1622*da0073e9SAndroid Build Coastguard Worker graph(%x, %y, %z): 1623*da0073e9SAndroid Build Coastguard Worker %o = my::double_mul(%x, %y, %z) 1624*da0073e9SAndroid Build Coastguard Worker return (%o)""" 1625*da0073e9SAndroid Build Coastguard Worker scripted_model = torch.jit.script(TestModule2()) 1626*da0073e9SAndroid Build Coastguard Worker graph_copy = scripted_model.graph.copy() 1627*da0073e9SAndroid Build Coastguard Worker for node in graph_copy.nodes(): 1628*da0073e9SAndroid Build Coastguard Worker if node.kind() == "aten::mul": 1629*da0073e9SAndroid Build Coastguard Worker source_range_1 = node.sourceRange() 1630*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_custom_pattern_based_rewrite_graph(input_pattern, replacement_pattern, graph_copy) 1631*da0073e9SAndroid Build Coastguard Worker for node in graph_copy.nodes(): 1632*da0073e9SAndroid Build Coastguard Worker if node.kind() == "my::double_mul": 1633*da0073e9SAndroid Build Coastguard Worker source_range_2 = node.sourceRange() 1634*da0073e9SAndroid Build Coastguard Worker self.assertFalse(source_range_1 == source_range_2) 1635*da0073e9SAndroid Build Coastguard Worker 1636*da0073e9SAndroid Build Coastguard Worker def test_expand_quantlint(self): 1637*da0073e9SAndroid Build Coastguard Worker pass 1638*da0073e9SAndroid Build Coastguard Worker 1639*da0073e9SAndroid Build Coastguard Worker def test_expand_fold_quant_inputs(self): 1640*da0073e9SAndroid Build Coastguard Worker pass 1641*da0073e9SAndroid Build Coastguard Worker 1642*da0073e9SAndroid Build Coastguard Worker def test_shape_analysis_broadcast(self): 1643*da0073e9SAndroid Build Coastguard Worker def broadcast(a, b): 1644*da0073e9SAndroid Build Coastguard Worker return a + b 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 1, 5, requires_grad=True) 1647*da0073e9SAndroid Build Coastguard Worker y = torch.randn(4, 1, 8, 5, requires_grad=True) 
1648*da0073e9SAndroid Build Coastguard Worker 1649*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(broadcast).graph 1650*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False) 1651*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Float(4, 3, 8, 5, strides=[120, 40, 5, 1], device=cpu)").run(str(graph)) 1652*da0073e9SAndroid Build Coastguard Worker 1653*da0073e9SAndroid Build Coastguard Worker def test_shape_analysis_unsqueeze_in_loop(self): 1654*da0073e9SAndroid Build Coastguard Worker input_str = """graph(%x.1 : Tensor): 1655*da0073e9SAndroid Build Coastguard Worker %4 : bool = prim::Constant[value=1]() 1656*da0073e9SAndroid Build Coastguard Worker %1 : int = prim::Constant[value=2]() 1657*da0073e9SAndroid Build Coastguard Worker %7 : int = prim::Constant[value=0]() 1658*da0073e9SAndroid Build Coastguard Worker # CHECK: FloatTensor(requires_grad=0, device=cpu) = prim::Loop 1659*da0073e9SAndroid Build Coastguard Worker %x : Tensor = prim::Loop(%1, %4, %x.1) 1660*da0073e9SAndroid Build Coastguard Worker # CHECK: : FloatTensor(requires_grad=0, device=cpu)): 1661*da0073e9SAndroid Build Coastguard Worker block0(%i : int, %x.6 : Tensor): 1662*da0073e9SAndroid Build Coastguard Worker # CHECK: FloatTensor(requires_grad=0, device=cpu) = aten::unsqueeze 1663*da0073e9SAndroid Build Coastguard Worker %x.3 : Tensor = aten::unsqueeze(%x.6, %7) 1664*da0073e9SAndroid Build Coastguard Worker -> (%4, %x.3) 1665*da0073e9SAndroid Build Coastguard Worker return (%x)""" 1666*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(input_str) 1667*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_complete_shape_analysis(graph, (torch.zeros(2, 2, dtype=torch.float32),), False) 1668*da0073e9SAndroid Build Coastguard Worker FileCheck().run(input_str, graph) 1669*da0073e9SAndroid Build Coastguard Worker 1670*da0073e9SAndroid Build Coastguard Worker def test_script_tensor_type(self): 
1671*da0073e9SAndroid Build Coastguard Worker def foo(x, t: torch.dtype): 1672*da0073e9SAndroid Build Coastguard Worker return x.type(t) 1673*da0073e9SAndroid Build Coastguard Worker scr = torch.jit.script(foo) 1674*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 1675*da0073e9SAndroid Build Coastguard Worker for t in [torch.int8, torch.float64, torch.float32, 1676*da0073e9SAndroid Build Coastguard Worker torch.bfloat16, torch.complex64, torch.complex128, torch.bool]: 1677*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scr(x, t), foo(x, t)) 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Worker def test_script_bool_literal_conversion(self): 1680*da0073e9SAndroid Build Coastguard Worker def foo(x): 1681*da0073e9SAndroid Build Coastguard Worker return torch.mul(x, True) 1682*da0073e9SAndroid Build Coastguard Worker scr = torch.jit.script(foo) 1683*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 1684*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scr(x), foo(x)) 1685*da0073e9SAndroid Build Coastguard Worker 1686*da0073e9SAndroid Build Coastguard Worker def test_shape_analysis_masked_select(self): 1687*da0073e9SAndroid Build Coastguard Worker input_str = """graph(%0 : Float(), 1688*da0073e9SAndroid Build Coastguard Worker %1 : Bool()): 1689*da0073e9SAndroid Build Coastguard Worker # CHECK: Float(*, requires_grad=0, device=cpu) = aten::masked_select 1690*da0073e9SAndroid Build Coastguard Worker %2 : Tensor = aten::masked_select(%0, %1) # test/test_jit.py:15261:0 1691*da0073e9SAndroid Build Coastguard Worker return (%2)""" 1692*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(input_str) 1693*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1, dtype=torch.float32)[0] 1694*da0073e9SAndroid Build Coastguard Worker mask = x.ge(0.5) 1695*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_complete_shape_analysis(graph, (x, mask), False) 1696*da0073e9SAndroid Build 
Coastguard Worker FileCheck().run(input_str, graph) 1697*da0073e9SAndroid Build Coastguard Worker 1698*da0073e9SAndroid Build Coastguard Worker # TODO: update verify to work with GraphExecutors 1699*da0073e9SAndroid Build Coastguard Worker @unittest.skip("verify needs to be updated to work with GraphExecutors") 1700*da0073e9SAndroid Build Coastguard Worker def test_verify(self): 1701*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.4], requires_grad=True) 1702*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([0.7], requires_grad=True) 1703*da0073e9SAndroid Build Coastguard Worker 1704*da0073e9SAndroid Build Coastguard Worker @torch.jit.compile 1705*da0073e9SAndroid Build Coastguard Worker def f(x, y): 1706*da0073e9SAndroid Build Coastguard Worker z = torch.sigmoid(x * (x + y)) 1707*da0073e9SAndroid Build Coastguard Worker w = torch.abs(x * x * x + y) + Variable(torch.ones(1)) 1708*da0073e9SAndroid Build Coastguard Worker return z, w 1709*da0073e9SAndroid Build Coastguard Worker 1710*da0073e9SAndroid Build Coastguard Worker torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[]) 1711*da0073e9SAndroid Build Coastguard Worker 1712*da0073e9SAndroid Build Coastguard Worker # TODO: adapt to a GraphExecutor test 1713*da0073e9SAndroid Build Coastguard Worker @unittest.skip("Need to instrument GraphExecutors a bit more") 1714*da0073e9SAndroid Build Coastguard Worker def test_flags(self): 1715*da0073e9SAndroid Build Coastguard Worker x, y = torch.randn(2, 2) 1716*da0073e9SAndroid Build Coastguard Worker y = Variable(torch.randn(2, 2)) 1717*da0073e9SAndroid Build Coastguard Worker 1718*da0073e9SAndroid Build Coastguard Worker @torch.jit.compile 1719*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 1720*da0073e9SAndroid Build Coastguard Worker return (x * x + y * y + x * y).sum() 1721*da0073e9SAndroid Build Coastguard Worker 1722*da0073e9SAndroid Build Coastguard Worker grads = {} 1723*da0073e9SAndroid Build Coastguard Worker for rx, ry 
in product((True, False), repeat=2): 1724*da0073e9SAndroid Build Coastguard Worker x.requires_grad = rx 1725*da0073e9SAndroid Build Coastguard Worker y.requires_grad = ry 1726*da0073e9SAndroid Build Coastguard Worker 1727*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fn.has_trace_for(x, y)) 1728*da0073e9SAndroid Build Coastguard Worker out = fn(x, y) 1729*da0073e9SAndroid Build Coastguard Worker 1730*da0073e9SAndroid Build Coastguard Worker self.assertFalse(fn.has_trace_for(x, y)) 1731*da0073e9SAndroid Build Coastguard Worker for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]: 1732*da0073e9SAndroid Build Coastguard Worker if not compute: 1733*da0073e9SAndroid Build Coastguard Worker continue 1734*da0073e9SAndroid Build Coastguard Worker grad_v, = torch.autograd.grad(out, v, retain_graph=True) 1735*da0073e9SAndroid Build Coastguard Worker expected_grad = grads.setdefault(name, grad_v) 1736*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_v, expected_grad) 1737*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn.has_trace_for(x, y), rx or ry) 1738*da0073e9SAndroid Build Coastguard Worker 1739*da0073e9SAndroid Build Coastguard Worker def test_python_ir(self): 1740*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.4], requires_grad=True) 1741*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([0.7], requires_grad=True) 1742*da0073e9SAndroid Build Coastguard Worker 1743*da0073e9SAndroid Build Coastguard Worker def doit(x, y): 1744*da0073e9SAndroid Build Coastguard Worker return torch.sigmoid(torch.tanh(x * (x + y))) 1745*da0073e9SAndroid Build Coastguard Worker 1746*da0073e9SAndroid Build Coastguard Worker g, _ = torch.jit._get_trace_graph(doit, (x, y)) 1747*da0073e9SAndroid Build Coastguard Worker self.run_pass('dce', g) 1748*da0073e9SAndroid Build Coastguard Worker self.run_pass('canonicalize', g) 1749*da0073e9SAndroid Build Coastguard Worker g2 = torch._C.Graph() 1750*da0073e9SAndroid Build Coastguard Worker 
g_to_g2 = {} 1751*da0073e9SAndroid Build Coastguard Worker for node in g.inputs(): 1752*da0073e9SAndroid Build Coastguard Worker g_to_g2[node] = g2.addInput() 1753*da0073e9SAndroid Build Coastguard Worker for node in g.nodes(): 1754*da0073e9SAndroid Build Coastguard Worker n_ = g2.createClone(node, lambda x: g_to_g2[x]) 1755*da0073e9SAndroid Build Coastguard Worker g2.appendNode(n_) 1756*da0073e9SAndroid Build Coastguard Worker for o, no in zip(node.outputs(), n_.outputs()): 1757*da0073e9SAndroid Build Coastguard Worker g_to_g2[o] = no 1758*da0073e9SAndroid Build Coastguard Worker 1759*da0073e9SAndroid Build Coastguard Worker for node in g.outputs(): 1760*da0073e9SAndroid Build Coastguard Worker g2.registerOutput(g_to_g2[node]) 1761*da0073e9SAndroid Build Coastguard Worker 1762*da0073e9SAndroid Build Coastguard Worker t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2])) 1763*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t_node.attributeNames(), ["a"]) 1764*da0073e9SAndroid Build Coastguard Worker g2.appendNode(t_node) 1765*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a"))) 1766*da0073e9SAndroid Build Coastguard Worker for node in g.nodes(): 1767*da0073e9SAndroid Build Coastguard Worker self.assertTrue(g2.findNode(node.kind()) is not None) 1768*da0073e9SAndroid Build Coastguard Worker 1769*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle") 1770*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda") 1771*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch._C._jit_has_cpp_tests(), "Tests were not built, use BUILD_TEST=1") 1772*da0073e9SAndroid Build Coastguard Worker def test_cpp(self): 1773*da0073e9SAndroid Build Coastguard Worker from cpp.jit import tests_setup 1774*da0073e9SAndroid Build Coastguard Worker tests_setup.setup() 1775*da0073e9SAndroid Build Coastguard Worker 
torch._C._jit_run_cpp_tests() 1776*da0073e9SAndroid Build Coastguard Worker tests_setup.shutdown() 1777*da0073e9SAndroid Build Coastguard Worker 1778*da0073e9SAndroid Build Coastguard Worker def test_batchnorm(self): 1779*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2, 2, 2) 1780*da0073e9SAndroid Build Coastguard Worker g, outputs, inputs = torch.jit._get_trace_graph(nn.BatchNorm2d(2), x, 1781*da0073e9SAndroid Build Coastguard Worker _force_outplace=True, return_inputs=True) 1782*da0073e9SAndroid Build Coastguard Worker m = self.createFunctionFromGraph(g) 1783*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs, m(*inputs)) 1784*da0073e9SAndroid Build Coastguard Worker 1785*da0073e9SAndroid Build Coastguard Worker def test_dropout(self): 1786*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 1787*da0073e9SAndroid Build Coastguard Worker with torch.random.fork_rng(devices=[]): 1788*da0073e9SAndroid Build Coastguard Worker g, outputs, inputs = torch.jit._get_trace_graph(nn.Dropout(0.6), x, return_inputs=True) 1789*da0073e9SAndroid Build Coastguard Worker with torch.random.fork_rng(devices=[]): 1790*da0073e9SAndroid Build Coastguard Worker m = self.createFunctionFromGraph(g) 1791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs, m(*inputs)) 1792*da0073e9SAndroid Build Coastguard Worker 1793*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "test requires CUDA") 1794*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") 1795*da0073e9SAndroid Build Coastguard Worker def test_native_dropout_corner_case(self): 1796*da0073e9SAndroid Build Coastguard Worker with disable_autodiff_subgraph_inlining(): 1797*da0073e9SAndroid Build Coastguard Worker def t(x, p: float, t: bool): 1798*da0073e9SAndroid Build Coastguard Worker o = torch.dropout(x, p, t) 1799*da0073e9SAndroid Build Coastguard Worker return o 
1800*da0073e9SAndroid Build Coastguard Worker 1801*da0073e9SAndroid Build Coastguard Worker jit_t = torch.jit.script(t) 1802*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5).requires_grad_() 1803*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::DifferentiableGraph").run(jit_t.graph_for(x, 1.0, True, profile_and_replay=True)) 1804*da0073e9SAndroid Build Coastguard Worker 1805*da0073e9SAndroid Build Coastguard Worker for train in [True, False]: 1806*da0073e9SAndroid Build Coastguard Worker for p in [0.0, 1.0]: 1807*da0073e9SAndroid Build Coastguard Worker for device in ["cuda", "cpu"]: 1808*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5).to(device=device).requires_grad_() 1809*da0073e9SAndroid Build Coastguard Worker x_ref = x.detach().requires_grad_() 1810*da0073e9SAndroid Build Coastguard Worker o = jit_t(x, p, train) 1811*da0073e9SAndroid Build Coastguard Worker o_ref = t(x_ref, p, train) 1812*da0073e9SAndroid Build Coastguard Worker o.sum().backward() 1813*da0073e9SAndroid Build Coastguard Worker o_ref.sum().backward() 1814*da0073e9SAndroid Build Coastguard Worker assert o.equal(o_ref) 1815*da0073e9SAndroid Build Coastguard Worker assert x.grad.equal(x_ref.grad) 1816*da0073e9SAndroid Build Coastguard Worker 1817*da0073e9SAndroid Build Coastguard Worker @slowTest 1818*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 'Testing differentiable graph') 1819*da0073e9SAndroid Build Coastguard Worker def test_dropout_module_requires_grad(self): 1820*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 1821*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 1822*da0073e9SAndroid Build Coastguard Worker def __init__(self, M): 1823*da0073e9SAndroid Build Coastguard Worker super().__init__() 1824*da0073e9SAndroid Build Coastguard Worker self.dropout = torch.nn.Dropout(0.5) 1825*da0073e9SAndroid Build Coastguard Worker 
self.linear = torch.nn.Linear(M, M) 1826*da0073e9SAndroid Build Coastguard Worker 1827*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 1828*da0073e9SAndroid Build Coastguard Worker input = self.dropout(input) 1829*da0073e9SAndroid Build Coastguard Worker output = self.linear(input) 1830*da0073e9SAndroid Build Coastguard Worker return output 1831*da0073e9SAndroid Build Coastguard Worker 1832*da0073e9SAndroid Build Coastguard Worker def profile(func, X): 1833*da0073e9SAndroid Build Coastguard Worker with torch.autograd.profiler.profile() as prof: 1834*da0073e9SAndroid Build Coastguard Worker func(X) 1835*da0073e9SAndroid Build Coastguard Worker return [e.name for e in prof.function_events] 1836*da0073e9SAndroid Build Coastguard Worker 1837*da0073e9SAndroid Build Coastguard Worker M = 1000 1838*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(MyModule(M)) 1839*da0073e9SAndroid Build Coastguard Worker # To reduce confusion about expected behaviors: 1840*da0073e9SAndroid Build Coastguard Worker # requires_grad controls whether dropout is symbolically differentiated. 1841*da0073e9SAndroid Build Coastguard Worker # training controls whether bernoulli_ is called inside symbolic differentiation of dropout. 1842*da0073e9SAndroid Build Coastguard Worker # * When requires_grad == training, the expected behaviors are obvious. 1843*da0073e9SAndroid Build Coastguard Worker # * When requires_grad=True and training=False, bernoulli_ might still show up in the graph. 1844*da0073e9SAndroid Build Coastguard Worker # But it's in a branch that's not called. That's why we have separate checks for autograd 1845*da0073e9SAndroid Build Coastguard Worker # profiler to make sure it's not run. 1846*da0073e9SAndroid Build Coastguard Worker # * When requires_grad=False and training=True, bernoulli_ must be run since it's the expected 1847*da0073e9SAndroid Build Coastguard Worker # behavior for the dropout layer in training mode. 
It's independent of whether graph requires 1848*da0073e9SAndroid Build Coastguard Worker # gradient. In fact bernoulli_ comes from autograd instead of autodiff in this case. 1849*da0073e9SAndroid Build Coastguard Worker for training in (True, False): 1850*da0073e9SAndroid Build Coastguard Worker if training: 1851*da0073e9SAndroid Build Coastguard Worker scripted.train() 1852*da0073e9SAndroid Build Coastguard Worker else: 1853*da0073e9SAndroid Build Coastguard Worker scripted.eval() 1854*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 1855*da0073e9SAndroid Build Coastguard Worker X = torch.randn(M, M, requires_grad=requires_grad) 1856*da0073e9SAndroid Build Coastguard Worker if requires_grad: 1857*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::native_dropout").run(scripted.graph_for(X, profile_and_replay=True)) 1858*da0073e9SAndroid Build Coastguard Worker self.assertEqual(training, 'aten::bernoulli_' in profile(scripted, X)) 1859*da0073e9SAndroid Build Coastguard Worker 1860*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, 'Testing differentiable graph') 1861*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls") 1862*da0073e9SAndroid Build Coastguard Worker def test_dropout_func_requires_grad(self): 1863*da0073e9SAndroid Build Coastguard Worker def dropout_training(input): 1864*da0073e9SAndroid Build Coastguard Worker return F.dropout(input, 0.5, training=True) 1865*da0073e9SAndroid Build Coastguard Worker 1866*da0073e9SAndroid Build Coastguard Worker def dropout_eval(input): 1867*da0073e9SAndroid Build Coastguard Worker return F.dropout(input, 0.5, training=False) 1868*da0073e9SAndroid Build Coastguard Worker 1869*da0073e9SAndroid Build Coastguard Worker def profile(func, X): 1870*da0073e9SAndroid Build Coastguard Worker with torch.autograd.profiler.profile() as prof: 1871*da0073e9SAndroid Build 
Coastguard Worker func(X) 1872*da0073e9SAndroid Build Coastguard Worker return [e.name for e in prof.function_events] 1873*da0073e9SAndroid Build Coastguard Worker 1874*da0073e9SAndroid Build Coastguard Worker M = 1000 1875*da0073e9SAndroid Build Coastguard Worker scripted_training = torch.jit.script(dropout_training) 1876*da0073e9SAndroid Build Coastguard Worker scripted_eval = torch.jit.script(dropout_eval) 1877*da0073e9SAndroid Build Coastguard Worker # See comments in test_dropout_module_requires_grad. 1878*da0073e9SAndroid Build Coastguard Worker with disable_autodiff_subgraph_inlining(): 1879*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 1880*da0073e9SAndroid Build Coastguard Worker X = torch.randn(M, M, requires_grad=requires_grad) 1881*da0073e9SAndroid Build Coastguard Worker if requires_grad: 1882*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::native_dropout").run(scripted_training.graph_for(X, profile_and_replay=True)) 1883*da0073e9SAndroid Build Coastguard Worker self.assertIn('aten::bernoulli_', profile(scripted_training, X)) 1884*da0073e9SAndroid Build Coastguard Worker self.assertNotIn('aten::bernoulli_', profile(scripted_eval, X)) 1885*da0073e9SAndroid Build Coastguard Worker 1886*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA") 1887*da0073e9SAndroid Build Coastguard Worker def test_dropout_cuda(self): 1888*da0073e9SAndroid Build Coastguard Worker # Dropout AD is dispatched to _fused_dropout in CUDA case, 1889*da0073e9SAndroid Build Coastguard Worker # which is not included in TestJitGeneratedFunctional 1890*da0073e9SAndroid Build Coastguard Worker def _zero_rate(t): 1891*da0073e9SAndroid Build Coastguard Worker return torch.true_divide((t == 0).sum(), t.numel()) 1892*da0073e9SAndroid Build Coastguard Worker 1893*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1000, 1000).cuda().requires_grad_() 1894*da0073e9SAndroid Build Coastguard 
Worker 1895*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 1896*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 1897*da0073e9SAndroid Build Coastguard Worker def func(x): 1898*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.dropout(x) 1899*da0073e9SAndroid Build Coastguard Worker 1900*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 1901*da0073e9SAndroid Build Coastguard Worker out_ref = torch.nn.functional.dropout(x) 1902*da0073e9SAndroid Build Coastguard Worker grad_ref = torch.autograd.grad(out_ref.sum(), x) 1903*da0073e9SAndroid Build Coastguard Worker 1904*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 1905*da0073e9SAndroid Build Coastguard Worker out = func(x) 1906*da0073e9SAndroid Build Coastguard Worker grad = torch.autograd.grad(out.sum(), x) 1907*da0073e9SAndroid Build Coastguard Worker 1908*da0073e9SAndroid Build Coastguard Worker # TODO(#40882): previously we assert exact matches between eager and JIT result: 1909*da0073e9SAndroid Build Coastguard Worker # self.assertEqual(out, out_ref) 1910*da0073e9SAndroid Build Coastguard Worker # self.assertEqual(grad, grad_ref) 1911*da0073e9SAndroid Build Coastguard Worker # This test was disabled during legacy -> profiling executor transition. 1912*da0073e9SAndroid Build Coastguard Worker # Currently JIT fused results doesn't match eager result exactly due to some changes merged in between. 1913*da0073e9SAndroid Build Coastguard Worker # We temporarily only check statstical difference but it should be reverted once the issue is fixed. 
1914*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_zero_rate(out), _zero_rate(out_ref), rtol=1e-3, atol=1e-4) 1915*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_zero_rate(grad[0]), _zero_rate(grad_ref[0]), rtol=1e-3, atol=1e-4) 1916*da0073e9SAndroid Build Coastguard Worker 1917*da0073e9SAndroid Build Coastguard Worker def test_torch_ops_overloaded(self): 1918*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "failed to match any schema"): 1919*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.add("a", 1) 1920*da0073e9SAndroid Build Coastguard Worker self.assertEqual("ab", torch.ops.aten.add("a", "b")) 1921*da0073e9SAndroid Build Coastguard Worker a, b = torch.rand(3, 4), torch.rand(3, 4) 1922*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a + b, torch.ops.aten.add(a, b)) 1923*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a + 1, torch.ops.aten.add(a, 1)) 1924*da0073e9SAndroid Build Coastguard Worker 1925*da0073e9SAndroid Build Coastguard Worker def test_torch_ops_kwonly(self): 1926*da0073e9SAndroid Build Coastguard Worker a, b = torch.rand(3, 4), torch.rand(3, 4) 1927*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "positional argument"): 1928*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.add(a, b, 2) 1929*da0073e9SAndroid Build Coastguard Worker # h/t Chillee for this ambiguous case 1930*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.prod(1), torch.ops.aten.prod(a, 1)) 1931*da0073e9SAndroid Build Coastguard Worker 1932*da0073e9SAndroid Build Coastguard Worker def test_torch_complex(self): 1933*da0073e9SAndroid Build Coastguard Worker def fn(real, img): 1934*da0073e9SAndroid Build Coastguard Worker return torch.complex(real, img) 1935*da0073e9SAndroid Build Coastguard Worker 1936*da0073e9SAndroid Build Coastguard Worker def fn_out(real, img, out): 1937*da0073e9SAndroid Build Coastguard Worker return torch.complex(real, 
img, out=out) 1938*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.rand(3, 4), torch.rand(3, 4), )) 1939*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.ones(5, 1, 4), torch.ones(5, 1, 4), )) 1940*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.zeros(1, 6), torch.ones(6, 1), )) 1941*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.zeros(1, 6), torch.zeros(6, 1), )) 1942*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.empty(3, 4), torch.empty(3, 4), )) 1943*da0073e9SAndroid Build Coastguard Worker 1944*da0073e9SAndroid Build Coastguard Worker real = torch.tensor([1, 2], dtype=torch.float32) 1945*da0073e9SAndroid Build Coastguard Worker img = torch.tensor([3, 4], dtype=torch.float32) 1946*da0073e9SAndroid Build Coastguard Worker out = torch.empty([3, 4], dtype=torch.complex64) 1947*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_out, (real, img, out, )) 1948*da0073e9SAndroid Build Coastguard Worker 1949*da0073e9SAndroid Build Coastguard Worker real = torch.tensor([5, 2], dtype=torch.float64) 1950*da0073e9SAndroid Build Coastguard Worker img = torch.tensor([3, 4], dtype=torch.float64) 1951*da0073e9SAndroid Build Coastguard Worker out = torch.empty([5, 2], dtype=torch.complex128) 1952*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_out, (real, img, out, )) 1953*da0073e9SAndroid Build Coastguard Worker 1954*da0073e9SAndroid Build Coastguard Worker real = torch.ones([1, 2]) 1955*da0073e9SAndroid Build Coastguard Worker img = torch.ones([1, 2]) 1956*da0073e9SAndroid Build Coastguard Worker out = torch.empty([1, 2], dtype=torch.complex64) 1957*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_out, (real, img, out, )) 1958*da0073e9SAndroid Build Coastguard Worker 1959*da0073e9SAndroid Build Coastguard Worker real = torch.ones([3, 8, 7]) 1960*da0073e9SAndroid Build Coastguard Worker img = torch.ones([3, 8, 7]) 1961*da0073e9SAndroid 
Build Coastguard Worker out = torch.empty([3, 8, 7], dtype=torch.complex64) 1962*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_out, (real, img, out, )) 1963*da0073e9SAndroid Build Coastguard Worker 1964*da0073e9SAndroid Build Coastguard Worker real = torch.empty([3, 2, 6]) 1965*da0073e9SAndroid Build Coastguard Worker img = torch.empty([3, 2, 6]) 1966*da0073e9SAndroid Build Coastguard Worker out = torch.empty([3, 2, 6], dtype=torch.complex64) 1967*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_out, (real, img, out, )) 1968*da0073e9SAndroid Build Coastguard Worker 1969*da0073e9SAndroid Build Coastguard Worker real = torch.zeros([1, 3]) 1970*da0073e9SAndroid Build Coastguard Worker img = torch.empty([3, 1]) 1971*da0073e9SAndroid Build Coastguard Worker out = torch.empty([3, 3], dtype=torch.complex64) 1972*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_out, (real, img, out, )) 1973*da0073e9SAndroid Build Coastguard Worker 1974*da0073e9SAndroid Build Coastguard Worker real = torch.ones([2, 5]) 1975*da0073e9SAndroid Build Coastguard Worker img = torch.empty([2, 1]) 1976*da0073e9SAndroid Build Coastguard Worker out = torch.empty([2, 5], dtype=torch.complex64) 1977*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_out, (real, img, out, )) 1978*da0073e9SAndroid Build Coastguard Worker 1979*da0073e9SAndroid Build Coastguard Worker real = torch.ones([2, 5]) 1980*da0073e9SAndroid Build Coastguard Worker img = torch.zeros([2, 1]) 1981*da0073e9SAndroid Build Coastguard Worker out = torch.empty([2, 5], dtype=torch.complex64) 1982*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_out, (real, img, out, )) 1983*da0073e9SAndroid Build Coastguard Worker 1984*da0073e9SAndroid Build Coastguard Worker def test_einsum(self): 1985*da0073e9SAndroid Build Coastguard Worker def check(fn, jitted, *args): 1986*da0073e9SAndroid Build Coastguard Worker self.assertGraphContains(jitted.graph, kind='aten::einsum') 
1987*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(*args), jitted(*args)) 1988*da0073e9SAndroid Build Coastguard Worker 1989*da0073e9SAndroid Build Coastguard Worker def equation_format(x, y): 1990*da0073e9SAndroid Build Coastguard Worker return torch.einsum('i,j->ij', (x, y)) 1991*da0073e9SAndroid Build Coastguard Worker 1992*da0073e9SAndroid Build Coastguard Worker def equation_format_varargs(x, y): 1993*da0073e9SAndroid Build Coastguard Worker return torch.einsum('i,j->ij', x, y) 1994*da0073e9SAndroid Build Coastguard Worker 1995*da0073e9SAndroid Build Coastguard Worker def sublist_format(x, y): 1996*da0073e9SAndroid Build Coastguard Worker return torch.einsum(x, [0], y, [1], [0, 1]) 1997*da0073e9SAndroid Build Coastguard Worker 1998*da0073e9SAndroid Build Coastguard Worker x = make_tensor((5,), dtype=torch.float32, device="cpu") 1999*da0073e9SAndroid Build Coastguard Worker y = make_tensor((10,), dtype=torch.float32, device="cpu") 2000*da0073e9SAndroid Build Coastguard Worker 2001*da0073e9SAndroid Build Coastguard Worker for fn in [equation_format, equation_format_varargs, sublist_format]: 2002*da0073e9SAndroid Build Coastguard Worker check(fn, torch.jit.script(fn), x, y) 2003*da0073e9SAndroid Build Coastguard Worker check(fn, torch.jit.trace(fn, (x, y)), x, y) 2004*da0073e9SAndroid Build Coastguard Worker 2005*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 2006*da0073e9SAndroid Build Coastguard Worker def test_python_ivalue(self): 2007*da0073e9SAndroid Build Coastguard Worker # Test if pure python object can be hold as IValue and conversion 2008*da0073e9SAndroid Build Coastguard Worker # between IValue and PyObject are correct 2009*da0073e9SAndroid Build Coastguard Worker # test for numpy object 2010*da0073e9SAndroid Build Coastguard Worker py_array = np.arange(15) 2011*da0073e9SAndroid Build Coastguard Worker ret_py_obj = torch._C._ivalue_debug_python_object(py_array) 2012*da0073e9SAndroid 
Build Coastguard Worker self.assertEqual(py_array, ret_py_obj) 2013*da0073e9SAndroid Build Coastguard Worker 2014*da0073e9SAndroid Build Coastguard Worker # test for function object 2015*da0073e9SAndroid Build Coastguard Worker ret_py_obj = torch._C._ivalue_debug_python_object(F.relu) 2016*da0073e9SAndroid Build Coastguard Worker self.assertEqual(F.relu, ret_py_obj) 2017*da0073e9SAndroid Build Coastguard Worker 2018*da0073e9SAndroid Build Coastguard Worker # test for memory management 2019*da0073e9SAndroid Build Coastguard Worker # we need to ensure IValue correctly call incref/decref to avoid 2020*da0073e9SAndroid Build Coastguard Worker # dangling behavior and potential memory leaks during conversions 2021*da0073e9SAndroid Build Coastguard Worker def test_func_scope_helper(inp): 2022*da0073e9SAndroid Build Coastguard Worker # create a scope and do the conversion -> ivalue -> pyobject 2023*da0073e9SAndroid Build Coastguard Worker # this func return a new pyobject that refcount + 1 2024*da0073e9SAndroid Build Coastguard Worker inp_refcount = sys.getrefcount(inp) 2025*da0073e9SAndroid Build Coastguard Worker ivalue_holder = torch._C._ivalue_debug_python_object(inp) 2026*da0073e9SAndroid Build Coastguard Worker self.assertEqual(inp_refcount + 1, sys.getrefcount(ivalue_holder)) 2027*da0073e9SAndroid Build Coastguard Worker return ivalue_holder + 1 2028*da0073e9SAndroid Build Coastguard Worker 2029*da0073e9SAndroid Build Coastguard Worker test_input = 2200 2030*da0073e9SAndroid Build Coastguard Worker before_count = sys.getrefcount(test_input) 2031*da0073e9SAndroid Build Coastguard Worker test_func_scope_helper(test_input) 2032*da0073e9SAndroid Build Coastguard Worker after_count = sys.getrefcount(test_input) 2033*da0073e9SAndroid Build Coastguard Worker 2034*da0073e9SAndroid Build Coastguard Worker # after the test_func_scope_helper_call, the refcount of 2035*da0073e9SAndroid Build Coastguard Worker # test_input should be equal to the original refcount 
2036*da0073e9SAndroid Build Coastguard Worker # otherwise we get either dangling pointer or memory leak! 2037*da0073e9SAndroid Build Coastguard Worker self.assertEqual(before_count, after_count) 2038*da0073e9SAndroid Build Coastguard Worker 2039*da0073e9SAndroid Build Coastguard Worker def test_decompose_addmm(self): 2040*da0073e9SAndroid Build Coastguard Worker def does_decompose(): 2041*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2042*da0073e9SAndroid Build Coastguard Worker def addmm(mat, mat1, mat2): 2043*da0073e9SAndroid Build Coastguard Worker a = mat.addmm(mat1, mat2) 2044*da0073e9SAndroid Build Coastguard Worker b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0) 2045*da0073e9SAndroid Build Coastguard Worker return a + b 2046*da0073e9SAndroid Build Coastguard Worker 2047*da0073e9SAndroid Build Coastguard Worker mat = torch.randn(2, 2) 2048*da0073e9SAndroid Build Coastguard Worker mat1 = torch.randn(2, 4) 2049*da0073e9SAndroid Build Coastguard Worker mat2 = torch.randn(4, 2) 2050*da0073e9SAndroid Build Coastguard Worker 2051*da0073e9SAndroid Build Coastguard Worker out_ref = addmm(mat, mat1, mat2) 2052*da0073e9SAndroid Build Coastguard Worker self.run_pass('decompose_ops', addmm.graph) 2053*da0073e9SAndroid Build Coastguard Worker out_test = addmm(mat, mat1, mat2) 2054*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out_test) 2055*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("addmm").run(str(addmm.graph)) 2056*da0073e9SAndroid Build Coastguard Worker 2057*da0073e9SAndroid Build Coastguard Worker def doesnt_decompose(): 2058*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2059*da0073e9SAndroid Build Coastguard Worker def addmm(mat, mat1, mat2, alpha, beta): 2060*da0073e9SAndroid Build Coastguard Worker a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0) 2061*da0073e9SAndroid Build Coastguard Worker b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta)) 2062*da0073e9SAndroid Build Coastguard Worker 
2063*da0073e9SAndroid Build Coastguard Worker return a + b 2064*da0073e9SAndroid Build Coastguard Worker 2065*da0073e9SAndroid Build Coastguard Worker orig = str(addmm.graph) 2066*da0073e9SAndroid Build Coastguard Worker self.run_pass('decompose_ops', addmm.graph) 2067*da0073e9SAndroid Build Coastguard Worker self.assertTrue(orig == str(addmm.graph)) 2068*da0073e9SAndroid Build Coastguard Worker 2069*da0073e9SAndroid Build Coastguard Worker does_decompose() 2070*da0073e9SAndroid Build Coastguard Worker doesnt_decompose() 2071*da0073e9SAndroid Build Coastguard Worker 2072*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 2073*da0073e9SAndroid Build Coastguard Worker def test_sparse_tensors(self): 2074*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 2075*da0073e9SAndroid Build Coastguard Worker def get_sparse(): 2076*da0073e9SAndroid Build Coastguard Worker return torch.sparse_coo_tensor((2, 3), dtype=torch.float32) 2077*da0073e9SAndroid Build Coastguard Worker 2078*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2079*da0073e9SAndroid Build Coastguard Worker def test_is_sparse(input): 2080*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> bool 2081*da0073e9SAndroid Build Coastguard Worker return input.is_sparse 2082*da0073e9SAndroid Build Coastguard Worker 2083*da0073e9SAndroid Build Coastguard Worker script_out_is_sparse = test_is_sparse(get_sparse()) 2084*da0073e9SAndroid Build Coastguard Worker script_out_is_dense = test_is_sparse(torch.randn(2, 3)) 2085*da0073e9SAndroid Build Coastguard Worker self.assertEqual(script_out_is_sparse, True) 2086*da0073e9SAndroid Build Coastguard Worker self.assertEqual(script_out_is_dense, False) 2087*da0073e9SAndroid Build Coastguard Worker 2088*da0073e9SAndroid Build Coastguard Worker def test_basic_sparse(input): 2089*da0073e9SAndroid Build Coastguard Worker output = get_sparse() 2090*da0073e9SAndroid Build Coastguard Worker return output, input 2091*da0073e9SAndroid Build Coastguard 
Worker 2092*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_basic_sparse, (get_sparse(),)) 2093*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_basic_sparse, (torch.tensor([1]),)) 2094*da0073e9SAndroid Build Coastguard Worker 2095*da0073e9SAndroid Build Coastguard Worker def test_sparse_sum(input): 2096*da0073e9SAndroid Build Coastguard Worker return torch.sparse.sum(input) 2097*da0073e9SAndroid Build Coastguard Worker 2098*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_sparse_sum, (get_sparse(),)) 2099*da0073e9SAndroid Build Coastguard Worker 2100*da0073e9SAndroid Build Coastguard Worker def test_sparse_mm(input1, input2): 2101*da0073e9SAndroid Build Coastguard Worker return torch.sparse.mm(input1, input2) 2102*da0073e9SAndroid Build Coastguard Worker 2103*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_sparse_mm, (get_sparse(), torch.randn(3, 4))) 2104*da0073e9SAndroid Build Coastguard Worker 2105*da0073e9SAndroid Build Coastguard Worker def test_sparse_addmm(input, input1, input2): 2106*da0073e9SAndroid Build Coastguard Worker return torch.sparse.addmm(input, input1, input2) 2107*da0073e9SAndroid Build Coastguard Worker 2108*da0073e9SAndroid Build Coastguard Worker def test_sparse_addmm_alpha_beta(input, input1, input2): 2109*da0073e9SAndroid Build Coastguard Worker return torch.sparse.addmm(input, input1, input2, alpha=1.3, beta=1.5) 2110*da0073e9SAndroid Build Coastguard Worker 2111*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_sparse_addmm, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4))) 2112*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_sparse_addmm_alpha_beta, (torch.randn(2, 4), get_sparse(), torch.randn(3, 4))) 2113*da0073e9SAndroid Build Coastguard Worker 2114*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 2115*da0073e9SAndroid Build Coastguard Worker def test_sparse_csr_tensors(self): 2116*da0073e9SAndroid Build Coastguard Worker 
@torch.jit.ignore 2117*da0073e9SAndroid Build Coastguard Worker def get_sparse_csr(): 2118*da0073e9SAndroid Build Coastguard Worker return torch.randn(3, 3).to_sparse_csr() 2119*da0073e9SAndroid Build Coastguard Worker 2120*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2121*da0073e9SAndroid Build Coastguard Worker def test_is_sparse_csr(input): 2122*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> bool 2123*da0073e9SAndroid Build Coastguard Worker return input.is_sparse_csr 2124*da0073e9SAndroid Build Coastguard Worker 2125*da0073e9SAndroid Build Coastguard Worker script_out_is_sparse_csr = test_is_sparse_csr(get_sparse_csr()) 2126*da0073e9SAndroid Build Coastguard Worker script_out_is_dense_csr = test_is_sparse_csr(torch.randn(3, 3)) 2127*da0073e9SAndroid Build Coastguard Worker 2128*da0073e9SAndroid Build Coastguard Worker self.assertEqual(script_out_is_sparse_csr, True) 2129*da0073e9SAndroid Build Coastguard Worker self.assertEqual(script_out_is_dense_csr, False) 2130*da0073e9SAndroid Build Coastguard Worker 2131*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "requires CUDA") 2132*da0073e9SAndroid Build Coastguard Worker def test_device_not_equal(self): 2133*da0073e9SAndroid Build Coastguard Worker 2134*da0073e9SAndroid Build Coastguard Worker def compare_device(x: torch.device): 2135*da0073e9SAndroid Build Coastguard Worker return x != torch.device("cuda:0") 2136*da0073e9SAndroid Build Coastguard Worker 2137*da0073e9SAndroid Build Coastguard Worker def compare_two_device(x: torch.device, y: torch.device): 2138*da0073e9SAndroid Build Coastguard Worker return x != y 2139*da0073e9SAndroid Build Coastguard Worker 2140*da0073e9SAndroid Build Coastguard Worker self.checkScript(compare_device, (torch.device("cuda:0"),)) 2141*da0073e9SAndroid Build Coastguard Worker self.checkScript(compare_two_device, (torch.device("cuda:0"), torch.device("cuda:1"), )) 2142*da0073e9SAndroid Build Coastguard Worker 2143*da0073e9SAndroid 
# NOTE(review): this span was recovered from a blame/code-search dump that had
# stripped newlines and indentation; the methods below are reconstructed as
# plain Python. They take `self` and use TestCase helpers (run_pass,
# checkScript, getExportImportCopy), so they belong to the enclosing JIT test
# class whose header is outside this chunk.
def test_constant_prop_simple(self):
    """Integer arithmetic on literals folds into a single constant (8)."""
    @torch.jit.script
    def constant_prop(input_int):
        # type: (int) -> int
        a = 2 * 3
        b = a + 2
        return b - input_int

    out_ref = constant_prop(2)
    self.run_pass('constant_propagation', constant_prop.graph)
    out_test = constant_prop(2)
    self.assertEqual(out_ref, out_test)
    graph_str = str(constant_prop.graph)
    # Both the multiply and the add only see literals, so neither op may
    # survive the pass (the subtraction with the runtime input stays).
    self.assertNotIn("aten::add", graph_str)
    self.assertNotIn("aten::mul", graph_str)
    const = constant_prop.graph.findNode("prim::Constant").output().toIValue()
    self.assertEqual(const, 8)

def test_constant_prop_nested(self):
    """Both branches of a data-dependent `if` reduce to bare constants."""
    @torch.jit.script
    def constant_prop(a):
        b = 2 + 1
        if bool(a < 2):
            c = b + 2
        else:
            c = b - 2
        return c

    out_ref = constant_prop(torch.tensor(2))
    self.run_pass('constant_propagation', constant_prop.graph)
    out_test = constant_prop(torch.tensor(2))
    self.assertEqual(out_ref, out_test)
    if_node = constant_prop.graph.findNode("prim::If")
    for block in if_node.blocks():
        for node in block.nodes():
            self.assertEqual(node.kind(), "prim::Constant")

def test_constant_prop_print(self):
    """A folded value feeding a side-effecting print becomes a constant input."""
    @torch.jit.script
    def constant_prop(input_tensor):
        a = 2 * 3
        print(a)
        b = a + 2
        return b + input_tensor

    self.run_pass('constant_propagation', constant_prop.graph)
    print_node = constant_prop.graph.findNode("prim::Print")
    self.assertEqual(print_node.input().toIValue(), 6)

def test_constant_prop_rand(self):
    """Nondeterministic ops such as randn must not be folded away."""
    @torch.jit.script
    def constant_prop():
        a = torch.randn([3])
        b = a + 2
        return b

    self.run_pass('constant_propagation', constant_prop.graph)
    self.assertIn("aten::randn", str(constant_prop.graph))

def test_constant_prop_none(self):
    """`is None` checks on typed-None results fold into a constant."""
    @torch.jit.script
    def typed_none():
        # type: () -> Optional[int]
        return None

    @torch.jit.script
    def constant_prop():
        a = typed_none()
        b = typed_none()
        if (a is None and b is None):
            a = 2
        else:
            a = 1
        return a

    self.run_pass('constant_propagation', constant_prop.graph)
    FileCheck().check("prim::Constant").run(constant_prop.graph)

def test_constant_prop_if_inline(self):
    """A constant-true `if` is inlined without evaluating the dead branch."""
    @torch.jit.script
    def constant_prop():
        cond = True
        a = 1
        if cond:
            a = 1 * 2
        else:
            a = 1 // 0
        return a

    # testing that the 1 // 0 error in the dead branch is not thrown
    self.run_pass('constant_propagation', constant_prop.graph)

def test_constant_prop_exception(self):
    """An out-of-range index on an unexecuted path must not break the pass."""
    # checking y = a[4] does not error in constant propagation
    def bad_index(x):
        # type: (bool)
        y = 0
        if x:
            a = [1, 2, 3]
            y = a[4]
        return y

    self.checkScript(bad_index, (False,))

def test_constant_prop_aliasing_type(self):
    """Mutable-typed values are not folded; pure literal conditions are."""
    @torch.jit.script
    def foo():
        return len([1]), len(torch.tensor([2]))

    FileCheck().check_dag("aten::tensor").check_dag("aten::len").run(foo.graph)

    @torch.jit.script
    def fn():
        if 1 == 1:
            return 1
        else:
            return 2

    FileCheck().check_not("prim::If").run(fn.graph)

def test_unchecked_cast(self):
    """Refining an Optional alias and writing through it mutates the original."""
    def test(cond):
        # type: (bool)
        a = torch.tensor([10])
        if cond:
            b = None
        else:
            b = a
        if b is not None:
            b[0] = 5
        return a.int()

    self.checkScript(test, (True,))
    self.checkScript(test, (False,))

def test_constant_prop_if_constant(self):
    """Constant-condition `if`s are inlined/removed and unchanged outputs pruned."""
    @torch.jit.script
    def constant_prop(a, b):
        c0 = 1
        c1 = 1
        c2 = 1
        if bool(a):  # -> c0, c1
            if bool(b):  # -> c0
                if 1 == 1:  # -> c0
                    c0 = c0 + 1
                    if 1 == 2:
                        c1 = c1 + 1
                        c2 = c2 + 1
        else:  # -> c0, c1
            c1 = c1 + 1

        if 1 == 1:  # inlined
            c0 = c0 + 1  # dynamic
            c2 = c2 + 4  # set to 5
        return a + c0 + c1 + c2

    graph = constant_prop.graph
    self.run_pass('constant_propagation', graph)
    ifs = graph.findAllNodes("prim::If", recurse=False)
    # the second top-level `if 1 == 1` must have been inlined
    self.assertEqual(len(ifs), 1)
    first_if = ifs[0]
    self.assertEqual(first_if.outputsSize(), 2)
    second_if = first_if.findNode("prim::If", recurse=False)
    self.assertEqual(second_if.outputsSize(), 1)
    self.assertIsNone(second_if.findNode("prim::If"))

def test_constant_prop_loop_constant(self):
    """Loops that can never run are removed; all other loops are kept."""
    @torch.jit.script
    def constant_prop(cond, iter):
        # type: (bool, int) -> int
        b = 0
        while True:
            print("stays")
        for _ in range(2):
            print("stays")
        for _ in range(iter):
            print("stays")
        while cond:
            print("stays")
        while False:
            print("removed")
        for _i in range(0):
            print("removed")
        for _i in range(-4):
            print("removed")
        return b

    self.run_pass('constant_propagation', constant_prop.graph)
    graph = canonical(constant_prop.graph)
    self.assertEqual(graph.count("removed"), 0)
    self.assertEqual(graph.count("stays"), 1)  # constant gets pooled
    self.assertEqual(graph.count("prim::Print"), 4)

def test_constant_prop_remove_output(self):
    """Loop-carried values that never change are dropped from loop outputs."""
    @torch.jit.script
    def constant_prop(iter):
        # type: (int) -> None
        a = 1
        b = 1
        c = 1
        for i in range(iter):
            if 1 == 2:
                a = 10
            if i == 5:
                b = 2
                c = 3
        print(a, b, c)

    graph = constant_prop.graph
    self.run_pass('constant_propagation', graph)
    # only b and c are actually modified inside the loop
    self.assertEqual(graph.findNode("prim::Loop").outputsSize(), 2)

# TODO(gmagogsfm): Refactor this test to reduce complexity.
def test_constant_insertion(self):
    """Constants of many types are inserted as single nodes and hashed correctly."""
    funcs_template = dedent('''
    def func():
        return {constant_constructor}
    ''')

    # constants: primitives: int, double, bool, str, lists of primitives,
    # and tuples
    def check_constant(constant_constructor):
        # Compile `return <constant_constructor>` in both Python and
        # TorchScript, require exactly one prim::Constant after propagation,
        # and compare results before and after an export/import round trip.
        scope = {}
        funcs_str = funcs_template.format(constant_constructor=constant_constructor)
        execWrapper(funcs_str, globals(), scope)
        cu = torch.jit.CompilationUnit(funcs_str)
        f_script = cu.func
        self.run_pass('constant_propagation', f_script.graph)
        FileCheck().check_count("prim::Constant", 1, exactly=True).run(f_script.graph)
        self.assertEqual(scope['func'](), f_script())
        imported = self.getExportImportCopy(f_script)
        self.assertEqual(imported(), f_script())

    constants = ["None", "-.5", "0", "1", "True", "False", "''", "'a'", "'b'", "torch.tensor(1)",
                 "[True, False]", "[0., .5]", "[torch.tensor(4), torch.tensor(2)]", "[0, 1]", "['0', '1']",
                 "[True, None]", "[.5, None, .2]"]

    for elem_type in ["Tensor", "str", "int", "float", "bool"]:
        constants.append("torch.jit.annotate(List[ " + elem_type + "], [])")

    for constant in constants:
        check_constant(constant)

    for key_type in ["str", "int", "float"]:
        for value_type in ["Tensor", "bool", "str", "int", "float"]:
            check_constant("torch.jit.annotate(Dict[ " + key_type + ", " + value_type + "], {})")
            check_constant("torch.jit.annotate(Dict[ " + key_type + ", Optional[" + value_type + "]], {})")

    # every unordered pair of distinct constants, as a 2-tuple
    for first, second in itertools.combinations(constants, 2):
        check_constant(first + ", " + second)

    dict_constants = []
    for key_constant in constants:
        # check_constant constructs the second dict with another Tensor
        # which fails the comparison
        if not isinstance(eval(key_constant), (str, int, float)):
            continue
        for value_constant in constants:
            dict_constant = "{ " + key_constant + ": " + value_constant + "}"
            check_constant(dict_constant)
            dict_constants.append(dict_constant)
    constants = constants + dict_constants

    # testing node hashing
    funcs_template = dedent('''
    def func():
        print({constant_constructor})
    ''')
    single_elem_tuples = ("(" + x + ",)" for x in constants)
    input_arg = ", ".join(single_elem_tuples)
    scope = {}
    funcs_str = funcs_template.format(constant_constructor=input_arg)
    execWrapper(funcs_str, globals(), scope)
    cu = torch.jit.CompilationUnit(funcs_str)
    f_script = cu.func
    self.run_pass('constant_propagation', f_script.graph)
    # prim::None return adds one constant
    self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant"))
    self.run_pass('cse', f_script.graph)
    # node hashing correctly working, no CSE occurs
    self.assertEqual(len(constants) + 1, str(f_script.graph).count("prim::Constant"))

    funcs_template = dedent('''
    def func():
        a = {constant_constructor}
        print(a)
        b = {constant_constructor}
        print(b)
    ''')

    # test that equal tuples and dicts correctly work with node hashing
    for tup in ("(" + x + ",)" for x in constants):
        funcs_str = funcs_template.format(constant_constructor=tup)
        scope = {}
        execWrapper(funcs_str, globals(), scope)
        cu = torch.jit.CompilationUnit(funcs_str)
        f_script = cu.func
        self.run_pass('constant_propagation_immutable_types', f_script.graph)
        num_constants = str(f_script.graph).count("prim::Constant")
        self.run_pass('cse', f_script.graph)
        FileCheck().check_count("prim::Constant", num_constants, exactly=True).run(f_script.graph)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_cuda_export_restore(self):
    """A nested ScriptModule moved to CUDA survives an export/import round trip."""
    class Sub(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.randn(3, 4))

        @torch.jit.script_method
        def forward(self, thing):
            return self.weight + thing

    class M(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.mod = Sub()

        @torch.jit.script_method
        def forward(self, v):
            return self.mod(v)

    m = M()
    m.cuda()
    m2 = self.getExportImportCopy(m)
    m2.cuda()
    inp = torch.rand(3, 4).cuda()
    self.assertEqual(m(inp), m2(inp))
Coastguard Worker 2476*da0073e9SAndroid Build Coastguard Worker @slowTest 2477*da0073e9SAndroid Build Coastguard Worker def test_export_batchnorm(self): 2478*da0073e9SAndroid Build Coastguard Worker for mode in ['eval', 'train']: 2479*da0073e9SAndroid Build Coastguard Worker for clazz in [ 2480*da0073e9SAndroid Build Coastguard Worker torch.nn.BatchNorm1d(100), 2481*da0073e9SAndroid Build Coastguard Worker torch.nn.BatchNorm1d(100, affine=False), 2482*da0073e9SAndroid Build Coastguard Worker torch.nn.BatchNorm2d(100), 2483*da0073e9SAndroid Build Coastguard Worker torch.nn.BatchNorm2d(100, affine=False)]: 2484*da0073e9SAndroid Build Coastguard Worker getattr(clazz, mode)() 2485*da0073e9SAndroid Build Coastguard Worker input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \ 2486*da0073e9SAndroid Build Coastguard Worker torch.randn(20, 100, 35, 45) 2487*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(clazz, (input,)) 2488*da0073e9SAndroid Build Coastguard Worker imported = self.getExportImportCopy(traced) 2489*da0073e9SAndroid Build Coastguard Worker x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \ 2490*da0073e9SAndroid Build Coastguard Worker torch.randn(20, 100, 35, 45) 2491*da0073e9SAndroid Build Coastguard Worker self.assertEqual(traced(x), imported(x)) 2492*da0073e9SAndroid Build Coastguard Worker 2493*da0073e9SAndroid Build Coastguard Worker def test_export_rnn(self): 2494*da0073e9SAndroid Build Coastguard Worker for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]: 2495*da0073e9SAndroid Build Coastguard Worker class RNNTest(torch.nn.Module): 2496*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2497*da0073e9SAndroid Build Coastguard Worker super().__init__() 2498*da0073e9SAndroid Build Coastguard Worker self.rnn = clazz 2499*da0073e9SAndroid Build Coastguard Worker 2500*da0073e9SAndroid Build Coastguard Worker def forward(self, x, lengths, h0): 2501*da0073e9SAndroid Build 
Coastguard Worker packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths) 2502*da0073e9SAndroid Build Coastguard Worker out, h = self.rnn(packed, h0) 2503*da0073e9SAndroid Build Coastguard Worker padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out) 2504*da0073e9SAndroid Build Coastguard Worker return padded_outs 2505*da0073e9SAndroid Build Coastguard Worker 2506*da0073e9SAndroid Build Coastguard Worker test = RNNTest() 2507*da0073e9SAndroid Build Coastguard Worker 2508*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20))) 2509*da0073e9SAndroid Build Coastguard Worker imported = self.getExportImportCopy(traced) 2510*da0073e9SAndroid Build Coastguard Worker # NB: We make sure to pass in a batch with a different max sequence 2511*da0073e9SAndroid Build Coastguard Worker # length to ensure that the argument stashing for pad_packed works 2512*da0073e9SAndroid Build Coastguard Worker # properly. 

        x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
        self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))

    def test_export_lstm(self):
        """Trace a packed-sequence LSTM and compare it with its export/import copy.

        The module is traced with inputs whose first dimension is 5 (lengths
        [3, 2, 1]); both the traced module and the re-imported copy are then
        run on inputs whose first dimension is 7 (lengths [7, 5, 2]), so the
        comparison does not reuse the trace-time shapes.
        """
        class LSTMTest(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.rnn = nn.LSTM(10, 20, 2)

            def forward(self, x, lengths, hiddens):
                h0, c0 = hiddens
                packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
                out, (h, c) = self.rnn(packed, (h0, c0))
                padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
                return padded_outs

        test = LSTMTest()

        # The hidden state is passed as a nested (h0, c0) tuple argument.
        traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
                                        torch.LongTensor([3, 2, 1]),
                                        (torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
        # NOTE(review): getExportImportCopy is a JitTestCase helper not shown
        # here; presumably it serializes `traced` and loads it back.
        imported = self.getExportImportCopy(traced)
        x, lengths, h0, c0 = \
            torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
        self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))

    def test_unique_state_dict(self):
        """A Parameter registered under two names must be reported once.

        ``w1`` and ``w2`` are the same Parameter object, so
        ``torch.jit._unique_state_dict`` must return a single entry with both
        ``keep_vars=False`` and ``keep_vars=True``.
        """
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                shared_param = torch.nn.Parameter(torch.ones(1))
                self.register_parameter('w1', shared_param)
                self.register_parameter('w2', shared_param)

            def forward(self, input):
                return input + self.w1 + self.w2

        model = MyModule()
        # NOTE(review): the base unittest.TestCase.assertEqual is called
        # explicitly instead of self.assertEqual; the reason is not visible here.
        unittest.TestCase.assertEqual(
            self, len(torch.jit._unique_state_dict(model, keep_vars=False)), 1)
        unittest.TestCase.assertEqual(
            self, len(torch.jit._unique_state_dict(model, keep_vars=True)), 1)

    def test_export_dropout(self):
        """A traced eval-mode Dropout must match its export/import copy."""
        test = torch.nn.Dropout()
        test.eval()

        # NOTE(review): check_trace=False disables the tracer's self-check;
        # the motivation is not stated in this test.
        traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False)
        imported = self.getExportImportCopy(traced)
        x = torch.randn(3, 4)
        self.assertEqual(traced(x), imported(x))

    def test_pretty_printer(self):
        """Compare the pretty-printed ``.code`` of small scripted functions.

        The functions cover if/else, an if without else, while loops,
        a loop with a value defined before and reassigned inside it,
        a call into an ``@torch.jit.ignore`` Python function, list literals
        and a ``print`` of a string containing an escape sequence. Each
        function's ``.code`` is checked with ``assertExpected`` under the name
        passed as its second argument.
        """
        @torch.jit.script
        def if_test(a, b):
            # FIXME: use 0 instead of a.
            # c = 0
            c = a
            if bool(a < b):
                c = b
            else:
                c = a
            return c

        @torch.jit.script
        def if_one(a, b):
            c = b
            if bool(a < b):
                c = a
            return c

        @torch.jit.script
        def while_test(a, i):
            while bool(i < 3):
                a *= a
                i += 1
            return a

        @torch.jit.script
        def while_if_test(a, b):
            c = 0
            while bool(a < 10):
                a = a + 1
                b = b + 1
                if bool(a > b):
                    c = 2
                else:
                    c = 3
            return a + 1 + c

        @torch.jit.script
        def loop_use_test(y):
            x = y + 1
            z = x + 5
            while bool(y < 8):
                y += 1
                z = x
            return x, z

        @torch.jit.ignore
        def python_fn(x):
            return x + 10

        @torch.jit.script
        def python_op_name_test(y):
            return python_fn(y)

        @torch.jit.script
        def empty_int_list_test(y):
            x = torch.jit.annotate(List[int], [])
            return x[0]

        @torch.jit.script
        def empty_float_list_test(y):
            return [1.0, 2.0, 3.0]

        @torch.jit.script
        def print_weird_test(y):
            print("hi\016")

        # NOTE(review): assertExpected is a JitTestCase helper not shown here;
        # presumably the second argument selects a recorded expected output,
        # so renaming these functions or names (e.g. empty_float_list_test,
        # which actually returns a non-empty list) would change what is
        # compared.
        self.assertExpected(if_test.code, "if_test")
        self.assertExpected(if_one.code, "if_one")
        self.assertExpected(while_test.code, "while_test")
        self.assertExpected(while_if_test.code, "while_if_test")
        self.assertExpected(loop_use_test.code, "loop_use_test")
        self.assertExpected(python_op_name_test.code, "python_op_name_test")
        self.assertExpected(empty_int_list_test.code, "empty_int_list_test")
        self.assertExpected(empty_float_list_test.code, "empty_float_list_test")
        self.assertExpected(print_weird_test.code, "print_weird_test")

    def test_cu_escaped_number(self):
        """Print code compiled from a CompilationUnit source string.

        The enclosing Python string literal is not raw, so the octal escape
        for U+000E is decoded by Python and the TorchScript compiler receives
        the control character itself inside its string literal (unlike
        print_weird_test above, whose source text keeps the escape sequence).
        """
        cu = torch.jit.CompilationUnit('''
            def foo(a):
                print("hi\016")
        ''')
        self.assertExpected(cu.foo.code)

    def test_import_method(self):
        """Save a ScriptModule to memory, load it back and check the loaded forward's code."""
        # NOTE(review): _disable_emit_hooks is a torch internal; presumably it
        # keeps test-harness emit hooks from interfering - not verifiable here.
        with torch._jit_internal._disable_emit_hooks():
            class Foo(torch.jit.ScriptModule):
                @torch.jit.script_method
                def forward(self, x, y):
                    return 2 * x + y

            foo = Foo()
            buffer = io.BytesIO()
            torch.jit.save(foo, buffer)

            # Rewind the in-memory file before loading from it.
            buffer.seek(0)
            foo_loaded = torch.jit.load(buffer)
            self.assertExpected(foo_loaded.forward.code)

    @unittest.skip("temporarily disable the test for fwd compatibility")
    def test_non_ascii_string(self):
        """(Skipped) Non-ASCII text in an attribute and a literal must survive save/load."""
        class Foo(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.a = "Over \u0e55\u0e57 57"

            @torch.jit.script_method
            def forward(self, x, y):
                return self.a + "hi\xA1"

        foo = Foo()
        buffer = io.BytesIO()
        torch.jit.save(foo, buffer)

        buffer.seek(0)
        foo_loaded = torch.jit.load(buffer)
        self.assertExpected(foo_loaded.forward.code)

    def test_function_default_values(self):
        """Scripted free functions must honor default argument values.

        Covered: tensor defaults captured from the enclosing scope (including
        a computed expression), overriding every default positionally, a bool
        tensor flag default, an ``Optional[int]`` default of None, and scalar
        defaults checked against ``# type:`` comments. Defaults that do not
        match the declared type must raise "Expected a default value".
        The ``# type:`` comment must stay the first line after each ``def``,
        so no other comments are placed inside the scripted functions.
        """
        outer_var = torch.tensor(20)
        outer_var2 = torch.tensor(30)
        a = torch.tensor(0.5)
        b = torch.tensor(10)

        @torch.jit.script
        def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
            return x + a + b + c

        self.assertEqual(
            simple_fn(torch.ones(1)),
            torch.ones(1) + 0.5 + 10 + (20 + 30))
        self.assertEqual(
            simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
            torch.ones(1) + 1 + 3 + 4)

        outer_c = torch.tensor(9)
        outer_flag = torch.tensor(False)

        @torch.jit.script
        def bool_fn(x, a=outer_c, flag=outer_flag):
            if bool(flag):
                result = x
            else:
                result = x + a
            return result

        self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
        self.assertEqual(
            bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
            torch.ones(1))

        @torch.jit.script
        def none_fn(x=None):
            # type: (Optional[int]) -> Optional[int]
            return x

        self.assertEqual(none_fn(), None)
        self.assertEqual(none_fn(1), 1)

        @torch.jit.script
        def hints(x, a=0.5, b=10):
            # type: (Tensor, float, int) -> Tensor
            return x + a + b

        self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)

        # Defaults swapped relative to the declared (float, int) types.
        with self.assertRaisesRegex(RuntimeError, "Expected a default value"):

            @torch.jit.script
            def hints_bad_types(x, a=10, b=0.5):  # noqa: T484
                # type: (Tensor, float, int) -> Tensor
                return x + a + b
        # A None default for a parameter whose declared type is not Optional.
        with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
            @torch.jit.script
            def bad_no_optional(x=None):
                # type: (Dict[str, int]) -> Dict[str, int]
                return x


    def test_module_default_values(self):
        """A script_method's tensor default (captured from the enclosing scope) is used when omitted."""
        four = torch.tensor(4)

        class Test(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input, other=four):
                return input + other

        t = Test()
        self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)

    def test_mutable_default_values(self):
        """Mutable defaults must be rejected by scripting.

        Checked for a list nested inside a tuple default on a free function
        and for a list default on ``nn.Module.forward``.
        """
        with self.assertRaisesRegex(Exception, "Mutable default parameters"):
            @torch.jit.script
            def foo(x=(1, [])):
                # type: (Tuple[int, List[Tensor]])
                return x

        class Test(torch.nn.Module):
            def forward(self, input=[]):  # noqa: B006
                return input

        with self.assertRaisesRegex(Exception, "Mutable default parameters"):
            torch.jit.script(Test())
Worker 2770*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 2771*da0073e9SAndroid Build Coastguard Worker def test_warnings(self): 2772*da0073e9SAndroid Build Coastguard Worker import warnings 2773*da0073e9SAndroid Build Coastguard Worker 2774*da0073e9SAndroid Build Coastguard Worker def fn(x): 2775*da0073e9SAndroid Build Coastguard Worker if bool(x < 2): 2776*da0073e9SAndroid Build Coastguard Worker warnings.warn("x is less than 2") 2777*da0073e9SAndroid Build Coastguard Worker return x 2778*da0073e9SAndroid Build Coastguard Worker 2779*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 2780*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2781*da0073e9SAndroid Build Coastguard Worker if bool(x < 2): 2782*da0073e9SAndroid Build Coastguard Worker warnings.warn("x is less than 2") 2783*da0073e9SAndroid Build Coastguard Worker return x 2784*da0073e9SAndroid Build Coastguard Worker 2785*da0073e9SAndroid Build Coastguard Worker 2786*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(M()) 2787*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(fn) 2788*da0073e9SAndroid Build Coastguard Worker 2789*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as warns: 2790*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(1)) 2791*da0073e9SAndroid Build Coastguard Worker 2792*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as script_warns: 2793*da0073e9SAndroid Build Coastguard Worker scripted_fn(torch.ones(1)) 2794*da0073e9SAndroid Build Coastguard Worker 2795*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as script_mod_warns: 2796*da0073e9SAndroid Build Coastguard Worker scripted_mod(torch.ones(1)) 2797*da0073e9SAndroid Build Coastguard Worker 2798*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(warns[0]), str(script_warns[0])) 
2799*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(script_mod_warns), 1) 2800*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(warns[0].message), str(script_mod_warns[0].message)) 2801*da0073e9SAndroid Build Coastguard Worker 2802*da0073e9SAndroid Build Coastguard Worker def test_no_erroneous_warnings(self): 2803*da0073e9SAndroid Build Coastguard Worker import warnings 2804*da0073e9SAndroid Build Coastguard Worker 2805*da0073e9SAndroid Build Coastguard Worker def fn(x): 2806*da0073e9SAndroid Build Coastguard Worker if bool(x > 0): 2807*da0073e9SAndroid Build Coastguard Worker warnings.warn('This should NOT be printed') 2808*da0073e9SAndroid Build Coastguard Worker x += 1 2809*da0073e9SAndroid Build Coastguard Worker return x 2810*da0073e9SAndroid Build Coastguard Worker 2811*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as warns: 2812*da0073e9SAndroid Build Coastguard Worker fn_script = torch.jit.script(fn) 2813*da0073e9SAndroid Build Coastguard Worker fn_script(torch.tensor(0)) 2814*da0073e9SAndroid Build Coastguard Worker warns = [str(w.message) for w in warns] 2815*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(warns), 0) 2816*da0073e9SAndroid Build Coastguard Worker 2817*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(True, "TODO: re-enable with https://github.com/pytorch/pytorch/pull/29339") 2818*da0073e9SAndroid Build Coastguard Worker def test_torch_load_error(self): 2819*da0073e9SAndroid Build Coastguard Worker class J(torch.jit.ScriptModule): 2820*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2821*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 2822*da0073e9SAndroid Build Coastguard Worker return input + 100 2823*da0073e9SAndroid Build Coastguard Worker 2824*da0073e9SAndroid Build Coastguard Worker j = J() 2825*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 2826*da0073e9SAndroid Build Coastguard 
Worker j.save(fname) 2827*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "is a zip"): 2828*da0073e9SAndroid Build Coastguard Worker torch.load(fname) 2829*da0073e9SAndroid Build Coastguard Worker 2830*da0073e9SAndroid Build Coastguard Worker def test_torch_load_zipfile_check(self): 2831*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2832*da0073e9SAndroid Build Coastguard Worker def fn(x): 2833*da0073e9SAndroid Build Coastguard Worker return x + 10 2834*da0073e9SAndroid Build Coastguard Worker 2835*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 2836*da0073e9SAndroid Build Coastguard Worker fn.save(fname) 2837*da0073e9SAndroid Build Coastguard Worker with open(fname, 'rb') as f: 2838*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.serialization._is_zipfile(f)) 2839*da0073e9SAndroid Build Coastguard Worker 2840*da0073e9SAndroid Build Coastguard Worker def test_python_bindings(self): 2841*da0073e9SAndroid Build Coastguard Worker lstm_cell = torch.jit.script(LSTMCellS) 2842*da0073e9SAndroid Build Coastguard Worker 2843*da0073e9SAndroid Build Coastguard Worker def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh): 2844*da0073e9SAndroid Build Coastguard Worker for i in range(x.size(0)): 2845*da0073e9SAndroid Build Coastguard Worker hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh) 2846*da0073e9SAndroid Build Coastguard Worker return hx 2847*da0073e9SAndroid Build Coastguard Worker 2848*da0073e9SAndroid Build Coastguard Worker slstm = torch.jit.script(lstm) 2849*da0073e9SAndroid Build Coastguard Worker 2850*da0073e9SAndroid Build Coastguard Worker inputs = get_lstm_inputs('cpu', training=True, seq_length=10) 2851*da0073e9SAndroid Build Coastguard Worker slstm(*inputs).sum().backward() 2852*da0073e9SAndroid Build Coastguard Worker global fw_graph 2853*da0073e9SAndroid Build Coastguard Worker fw_graph = slstm.graph_for(*inputs) 2854*da0073e9SAndroid Build Coastguard Worker nodes = 
list(fw_graph.nodes()) 2855*da0073e9SAndroid Build Coastguard Worker tested_blocks = False 2856*da0073e9SAndroid Build Coastguard Worker for node in nodes: 2857*da0073e9SAndroid Build Coastguard Worker for output in node.outputs(): 2858*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(output, 'type')) 2859*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output.type() is not None) 2860*da0073e9SAndroid Build Coastguard Worker for input in node.inputs(): 2861*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(input, 'type')) 2862*da0073e9SAndroid Build Coastguard Worker self.assertTrue(input.type() is not None) 2863*da0073e9SAndroid Build Coastguard Worker for block in node.blocks(): 2864*da0073e9SAndroid Build Coastguard Worker tested_blocks = True 2865*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(block, 'inputs')) 2866*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(block, 'outputs')) 2867*da0073e9SAndroid Build Coastguard Worker for output in block.outputs(): 2868*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(output, 'type')) 2869*da0073e9SAndroid Build Coastguard Worker self.assertTrue(output.type() is not None) 2870*da0073e9SAndroid Build Coastguard Worker for input in block.inputs(): 2871*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(input, 'type')) 2872*da0073e9SAndroid Build Coastguard Worker self.assertTrue(input.type() is not None) 2873*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(block, 'returnNode')) 2874*da0073e9SAndroid Build Coastguard Worker self.assertTrue(type(block.returnNode()) == torch._C.Node) 2875*da0073e9SAndroid Build Coastguard Worker self.assertTrue(hasattr(block, 'paramNode')) 2876*da0073e9SAndroid Build Coastguard Worker self.assertTrue(type(block.paramNode()) == torch._C.Node) 2877*da0073e9SAndroid Build Coastguard Worker self.assertTrue(tested_blocks) 2878*da0073e9SAndroid Build Coastguard Worker 
2879*da0073e9SAndroid Build Coastguard Worker def test_export_opnames(self): 2880*da0073e9SAndroid Build Coastguard Worker class Foo(torch.jit.ScriptModule): 2881*da0073e9SAndroid Build Coastguard Worker def one(self, x, y): 2882*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> Tensor 2883*da0073e9SAndroid Build Coastguard Worker return x + y 2884*da0073e9SAndroid Build Coastguard Worker 2885*da0073e9SAndroid Build Coastguard Worker def two(self, x): 2886*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 2887*da0073e9SAndroid Build Coastguard Worker return 2 * x 2888*da0073e9SAndroid Build Coastguard Worker 2889*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2890*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2891*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 2892*da0073e9SAndroid Build Coastguard Worker return self.one(self.two(x), x) 2893*da0073e9SAndroid Build Coastguard Worker 2894*da0073e9SAndroid Build Coastguard Worker class Bar(torch.jit.ScriptModule): 2895*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 2896*da0073e9SAndroid Build Coastguard Worker super().__init__() 2897*da0073e9SAndroid Build Coastguard Worker self.sub = Foo() 2898*da0073e9SAndroid Build Coastguard Worker 2899*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 2900*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2901*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 2902*da0073e9SAndroid Build Coastguard Worker return self.sub.forward(x) 2903*da0073e9SAndroid Build Coastguard Worker 2904*da0073e9SAndroid Build Coastguard Worker bar = Bar() 2905*da0073e9SAndroid Build Coastguard Worker ops = torch.jit.export_opnames(bar) 2906*da0073e9SAndroid Build Coastguard Worker expected = ['aten::add.Tensor', 'aten::mul.Scalar'] 2907*da0073e9SAndroid Build Coastguard Worker self.assertTrue(set(expected).issubset(set(ops))) 
2908*da0073e9SAndroid Build Coastguard Worker 2909*da0073e9SAndroid Build Coastguard Worker def test_pytorch_jit_env_off(self): 2910*da0073e9SAndroid Build Coastguard Worker import subprocess 2911*da0073e9SAndroid Build Coastguard Worker env = os.environ.copy() 2912*da0073e9SAndroid Build Coastguard Worker env['PYTORCH_JIT'] = '0' 2913*da0073e9SAndroid Build Coastguard Worker try: 2914*da0073e9SAndroid Build Coastguard Worker subprocess.check_output([sys.executable, '-c', 'import torch'], env=env) 2915*da0073e9SAndroid Build Coastguard Worker except subprocess.CalledProcessError as e: 2916*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Could not 'import torch' with PYTORCH_JIT=0") from e 2917*da0073e9SAndroid Build Coastguard Worker 2918*da0073e9SAndroid Build Coastguard Worker def test_print_op_module(self): 2919*da0073e9SAndroid Build Coastguard Worker # Issue #19351: python2 and python3 go through different paths. 2920*da0073e9SAndroid Build Coastguard Worker # python2 returns '<module 'torch.ops' (built-in)>' 2921*da0073e9SAndroid Build Coastguard Worker # python3 uses __file__ and return 2922*da0073e9SAndroid Build Coastguard Worker # '<module 'torch.ops' from '/scratch/ailzhang/pytorch/torch/_ops.py'>' 2923*da0073e9SAndroid Build Coastguard Worker s = str(torch.ops) 2924*da0073e9SAndroid Build Coastguard Worker self.assertRegex(s, r'ops') 2925*da0073e9SAndroid Build Coastguard Worker 2926*da0073e9SAndroid Build Coastguard Worker def test_print_classes_module(self): 2927*da0073e9SAndroid Build Coastguard Worker s = str(torch.classes) 2928*da0073e9SAndroid Build Coastguard Worker self.assertRegex(s, r'classes') 2929*da0073e9SAndroid Build Coastguard Worker 2930*da0073e9SAndroid Build Coastguard Worker def test_print_torch_ops_modules(self): 2931*da0073e9SAndroid Build Coastguard Worker s = str(torch._ops.ops.quantized) 2932*da0073e9SAndroid Build Coastguard Worker self.assertRegex(s, r'torch.ops') 2933*da0073e9SAndroid Build Coastguard Worker s = 
str(torch._ops.ops.atan) 2934*da0073e9SAndroid Build Coastguard Worker self.assertRegex(s, r'torch.ops') 2935*da0073e9SAndroid Build Coastguard Worker 2936*da0073e9SAndroid Build Coastguard Worker def test_hide_source_ranges_context_manager(self): 2937*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 2938*da0073e9SAndroid Build Coastguard Worker def foo(x): 2939*da0073e9SAndroid Build Coastguard Worker return torch.add(x, x) 2940*da0073e9SAndroid Build Coastguard Worker 2941*da0073e9SAndroid Build Coastguard Worker graph = foo.graph 2942*da0073e9SAndroid Build Coastguard Worker source_range_regex = "# .*\\.py" 2943*da0073e9SAndroid Build Coastguard Worker self.assertRegex(graph.__repr__(), source_range_regex) 2944*da0073e9SAndroid Build Coastguard Worker with torch.jit._hide_source_ranges(): 2945*da0073e9SAndroid Build Coastguard Worker self.assertNotRegex(graph.__repr__(), source_range_regex) 2946*da0073e9SAndroid Build Coastguard Worker self.assertRegex(graph.str(print_source_ranges=True), source_range_regex) 2947*da0073e9SAndroid Build Coastguard Worker self.assertRegex(graph.__repr__(), source_range_regex) 2948*da0073e9SAndroid Build Coastguard Worker 2949*da0073e9SAndroid Build Coastguard Worker 2950*da0073e9SAndroid Build Coastguard Workerclass TestFrontend(JitTestCase): 2951*da0073e9SAndroid Build Coastguard Worker 2952*da0073e9SAndroid Build Coastguard Worker def test_instancing_error(self): 2953*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 2954*da0073e9SAndroid Build Coastguard Worker class MyScriptClass: 2955*da0073e9SAndroid Build Coastguard Worker def unscriptable(self): 2956*da0073e9SAndroid Build Coastguard Worker return "a" + 200 2957*da0073e9SAndroid Build Coastguard Worker 2958*da0073e9SAndroid Build Coastguard Worker 2959*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.nn.Module): 2960*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 2961*da0073e9SAndroid Build Coastguard Worker return 
MyScriptClass() 2962*da0073e9SAndroid Build Coastguard Worker 2963*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(torch.jit.frontend.FrontendError) as cm: 2964*da0073e9SAndroid Build Coastguard Worker torch.jit.script(TestModule()) 2965*da0073e9SAndroid Build Coastguard Worker 2966*da0073e9SAndroid Build Coastguard Worker checker = FileCheck() 2967*da0073e9SAndroid Build Coastguard Worker checker.check("Cannot instantiate class") 2968*da0073e9SAndroid Build Coastguard Worker checker.check("def forward") 2969*da0073e9SAndroid Build Coastguard Worker checker.run(str(cm.exception)) 2970*da0073e9SAndroid Build Coastguard Worker 2971*da0073e9SAndroid Build Coastguard Worker def test_dictionary_as_example_inputs_for_jit_trace(self): 2972*da0073e9SAndroid Build Coastguard Worker class TestModule_v1(torch.nn.Module): 2973*da0073e9SAndroid Build Coastguard Worker def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None): 2974*da0073e9SAndroid Build Coastguard Worker return key1 + key2 + key3 2975*da0073e9SAndroid Build Coastguard Worker 2976*da0073e9SAndroid Build Coastguard Worker class TestModule_v2(torch.nn.Module): 2977*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 2978*da0073e9SAndroid Build Coastguard Worker return x + y 2979*da0073e9SAndroid Build Coastguard Worker 2980*da0073e9SAndroid Build Coastguard Worker def test_func(x, y): 2981*da0073e9SAndroid Build Coastguard Worker return x + y 2982*da0073e9SAndroid Build Coastguard Worker model_1 = TestModule_v1() 2983*da0073e9SAndroid Build Coastguard Worker model_2 = TestModule_v2() 2984*da0073e9SAndroid Build Coastguard Worker value1 = torch.ones(1) 2985*da0073e9SAndroid Build Coastguard Worker value2 = torch.ones(1) 2986*da0073e9SAndroid Build Coastguard Worker value3 = torch.ones(1) 2987*da0073e9SAndroid Build Coastguard Worker example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3} 2988*da0073e9SAndroid Build Coastguard Worker 
example_input_dict_func = {'x': value1, 'y': value2} 2989*da0073e9SAndroid Build Coastguard Worker traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False) 2990*da0073e9SAndroid Build Coastguard Worker traced_model_1_m = torch.jit.trace_module( 2991*da0073e9SAndroid Build Coastguard Worker model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False) 2992*da0073e9SAndroid Build Coastguard Worker traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])}) 2993*da0073e9SAndroid Build Coastguard Worker traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False) 2994*da0073e9SAndroid Build Coastguard Worker res_1 = traced_model_1(**example_input_dict) 2995*da0073e9SAndroid Build Coastguard Worker res_1_m = traced_model_1_m(**example_input_dict) 2996*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_1, 3 * torch.ones(1)) 2997*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_1_m, 3 * torch.ones(1)) 2998*da0073e9SAndroid Build Coastguard Worker res_func = traced_func(**example_input_dict_func) 2999*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_func, 2 * torch.ones(1)) 3000*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."): 3001*da0073e9SAndroid Build Coastguard Worker res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])}) # noqa: PIE804 3002*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."): 3003*da0073e9SAndroid Build Coastguard Worker res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])}) # noqa: PIE804 3004*da0073e9SAndroid Build Coastguard Worker 3005*da0073e9SAndroid Build Coastguard Worker 3006*da0073e9SAndroid Build Coastguard Workerclass TestScript(JitTestCase): 

    # Tests that calling torch.jit.script repeatedly on a function is allowed.
    def test_repeated_script_on_function(self):
        @torch.jit.script
        @torch.jit.script
        def fn(x):
            return x

        # Scripting the already doubly-scripted function twice more must also succeed.
        torch.jit.script(torch.jit.script(fn))

    def test_pretty_print_function(self):
        """The printed code of a scripted call to F.interpolate must mention "interpolate"."""
        @torch.jit.script
        def foo(x):
            return torch.nn.functional.interpolate(x)

        FileCheck().check("interpolate").run(foo.code)

    def test_inlined_graph(self):
        """
        Check that the `inlined_graph` property correctly returns an inlined
        graph, both through function calls and method calls.
        """
        @torch.jit.script
        def foo(x):
            return torch.add(x, x)

        class MyNestedMod(torch.nn.Module):
            def forward(self, x):
                return torch.sub(x, x)


        class MyMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = MyNestedMod()

            def forward(self, x):
                x = self.nested(x)  # sub
                x = foo(x)  # add
                return torch.mul(x, x)

        m = torch.jit.script(MyMod())
        # FileCheck matches in order: sub (submodule method), add (scripted
        # free function), then mul (MyMod.forward itself).
        FileCheck().check("aten::sub") \
            .check("aten::add") \
            .check("aten::mul") \
            .run(m.inlined_graph)

    def test_static_method_on_module(self):
        """
        Check that the `@staticmethod` annotation on a function on a module works.
        """
        class MyCell(torch.nn.Module):
            @staticmethod
            def do_it(x, h):
                new_h = torch.tanh(x + h)
                return new_h, new_h

            def forward(self, x, h):
                return self.do_it(x, h)

        my_cell = torch.jit.script(MyCell())
        x = torch.rand(3, 4)
        h = torch.rand(3, 4)
        # Compare the scripted module with calling the static method eagerly.
        jitted_cell = my_cell(x, h)
        non_jitted_cell = MyCell().do_it(x, h)

        self.assertEqual(jitted_cell, non_jitted_cell)

    def test_code_with_constants(self):
        """
        Check that the `code_with_constants` property correctly returns graph CONSTANTS in the
        CONSTANTS.cN format used in the output of the `code` property.
3079*da0073e9SAndroid Build Coastguard Worker """ 3080*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3081*da0073e9SAndroid Build Coastguard Worker def foo(x=torch.ones(1)): 3082*da0073e9SAndroid Build Coastguard Worker return x 3083*da0073e9SAndroid Build Coastguard Worker 3084*da0073e9SAndroid Build Coastguard Worker class Moddy(torch.nn.Module): 3085*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3086*da0073e9SAndroid Build Coastguard Worker return foo() 3087*da0073e9SAndroid Build Coastguard Worker 3088*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Moddy()) 3089*da0073e9SAndroid Build Coastguard Worker src, CONSTANTS = m.code_with_constants 3090*da0073e9SAndroid Build Coastguard Worker 3091*da0073e9SAndroid Build Coastguard Worker self.assertEqual(CONSTANTS.c0, torch.ones(1)) 3092*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src, m.code) 3093*da0073e9SAndroid Build Coastguard Worker 3094*da0073e9SAndroid Build Coastguard Worker def test_code_with_constants_restore(self): 3095*da0073e9SAndroid Build Coastguard Worker """ 3096*da0073e9SAndroid Build Coastguard Worker Check that the `code_with_constants` property correctly works on restoration after save() + load() 3097*da0073e9SAndroid Build Coastguard Worker """ 3098*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3099*da0073e9SAndroid Build Coastguard Worker def foo(x=torch.ones(1)): 3100*da0073e9SAndroid Build Coastguard Worker return x 3101*da0073e9SAndroid Build Coastguard Worker 3102*da0073e9SAndroid Build Coastguard Worker class Moddy(torch.nn.Module): 3103*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3104*da0073e9SAndroid Build Coastguard Worker return foo() 3105*da0073e9SAndroid Build Coastguard Worker 3106*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(Moddy()) 3107*da0073e9SAndroid Build Coastguard Worker src, CONSTANTS = m.code_with_constants 3108*da0073e9SAndroid Build Coastguard Worker eic = 
self.getExportImportCopy(m) 3109*da0073e9SAndroid Build Coastguard Worker 3110*da0073e9SAndroid Build Coastguard Worker src_eic, CONSTANTS_eic = eic.code_with_constants 3111*da0073e9SAndroid Build Coastguard Worker 3112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(src, src_eic) 3113*da0073e9SAndroid Build Coastguard Worker self.assertEqual(CONSTANTS.c0, CONSTANTS_eic.c0) 3114*da0073e9SAndroid Build Coastguard Worker 3115*da0073e9SAndroid Build Coastguard Worker 3116*da0073e9SAndroid Build Coastguard Worker def test_oneline_func(self): 3117*da0073e9SAndroid Build Coastguard Worker def fn(x): return x # noqa: E704 3118*da0073e9SAndroid Build Coastguard Worker 3119*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.ones(2, 2), )) 3120*da0073e9SAndroid Build Coastguard Worker 3121*da0073e9SAndroid Build Coastguard Worker def test_request_bailout(self): 3122*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 3123*da0073e9SAndroid Build Coastguard Worker 3124*da0073e9SAndroid Build Coastguard Worker def fct_loop(x): 3125*da0073e9SAndroid Build Coastguard Worker for i in range(3): 3126*da0073e9SAndroid Build Coastguard Worker x = torch.cat((x, x), 0) 3127*da0073e9SAndroid Build Coastguard Worker return x 3128*da0073e9SAndroid Build Coastguard Worker 3129*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 3, 4, dtype=torch.float32) 3130*da0073e9SAndroid Build Coastguard Worker expected = fct_loop(x) 3131*da0073e9SAndroid Build Coastguard Worker jitted = torch.jit.script(fct_loop) 3132*da0073e9SAndroid Build Coastguard Worker # profile 3133*da0073e9SAndroid Build Coastguard Worker jitted(x) 3134*da0073e9SAndroid Build Coastguard Worker # optimize 3135*da0073e9SAndroid Build Coastguard Worker jitted(x) 3136*da0073e9SAndroid Build Coastguard Worker dstate = jitted.get_debug_state() 3137*da0073e9SAndroid Build Coastguard Worker eplan = get_execution_plan(dstate) 3138*da0073e9SAndroid Build 
Coastguard Worker num_bailouts = eplan.code.num_bailouts() 3139*da0073e9SAndroid Build Coastguard Worker 3140*da0073e9SAndroid Build Coastguard Worker for i in range(0, num_bailouts): 3141*da0073e9SAndroid Build Coastguard Worker eplan.code.request_bailout(i) 3142*da0073e9SAndroid Build Coastguard Worker self.assertEqual(jitted(x), expected) 3143*da0073e9SAndroid Build Coastguard Worker 3144*da0073e9SAndroid Build Coastguard Worker @unittest.skip("bailouts are being deprecated") 3145*da0073e9SAndroid Build Coastguard Worker def test_dominated_bailout(self): 3146*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 3147*da0073e9SAndroid Build Coastguard Worker # functional dominated guard 3148*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3149*da0073e9SAndroid Build Coastguard Worker def foo(x): 3150*da0073e9SAndroid Build Coastguard Worker dim = x.dim() 3151*da0073e9SAndroid Build Coastguard Worker if dim == 0: 3152*da0073e9SAndroid Build Coastguard Worker y = int(x) 3153*da0073e9SAndroid Build Coastguard Worker else: 3154*da0073e9SAndroid Build Coastguard Worker y = x.size()[dim - 1] 3155*da0073e9SAndroid Build Coastguard Worker return y 3156*da0073e9SAndroid Build Coastguard Worker 3157*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(2) 3158*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(x), 2) 3159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(x), 2) 3160*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 3161*da0073e9SAndroid Build Coastguard Worker g_s = str(g) 3162*da0073e9SAndroid Build Coastguard Worker g_s = g_s[0:g_s.find("return")] 3163*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::BailOut[", 1, exactly=True).run(g_s) 3164*da0073e9SAndroid Build Coastguard Worker 3165*da0073e9SAndroid Build Coastguard Worker # dominated guard of non-functional value 3166*da0073e9SAndroid Build Coastguard Worker 
@torch.jit.script 3167*da0073e9SAndroid Build Coastguard Worker def foo(x): 3168*da0073e9SAndroid Build Coastguard Worker dim = x.dim() 3169*da0073e9SAndroid Build Coastguard Worker x.add_(3) 3170*da0073e9SAndroid Build Coastguard Worker if dim == 0: 3171*da0073e9SAndroid Build Coastguard Worker return 0 3172*da0073e9SAndroid Build Coastguard Worker else: 3173*da0073e9SAndroid Build Coastguard Worker return x.size()[dim - 1] 3174*da0073e9SAndroid Build Coastguard Worker 3175*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(2) 3176*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(x), 2) 3177*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(x), 2) 3178*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 3179*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::BailOut[").check("aten::add_").check_next("prim::BailOut[").check("return").run(g) 3180*da0073e9SAndroid Build Coastguard Worker 3181*da0073e9SAndroid Build Coastguard Worker with torch.enable_grad(): 3182*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 3183*da0073e9SAndroid Build Coastguard Worker def disable_grad(): 3184*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(False) 3185*da0073e9SAndroid Build Coastguard Worker 3186*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 3187*da0073e9SAndroid Build Coastguard Worker def enable_grad(): 3188*da0073e9SAndroid Build Coastguard Worker torch.set_grad_enabled(True) 3189*da0073e9SAndroid Build Coastguard Worker 3190*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3191*da0073e9SAndroid Build Coastguard Worker def foo(x): 3192*da0073e9SAndroid Build Coastguard Worker x = x + 1 3193*da0073e9SAndroid Build Coastguard Worker dim = x.dim() 3194*da0073e9SAndroid Build Coastguard Worker disable_grad() 3195*da0073e9SAndroid Build Coastguard Worker if dim == 0: 3196*da0073e9SAndroid Build Coastguard Worker y = int(x) 3197*da0073e9SAndroid Build 
Coastguard Worker else: 3198*da0073e9SAndroid Build Coastguard Worker y = x.size()[dim - 1] 3199*da0073e9SAndroid Build Coastguard Worker enable_grad() 3200*da0073e9SAndroid Build Coastguard Worker return y 3201*da0073e9SAndroid Build Coastguard Worker 3202*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(2, requires_grad=True) 3203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(x), 2) 3204*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(x), 2) 3205*da0073e9SAndroid Build Coastguard Worker g = torch.jit.last_executed_optimized_graph() 3206*da0073e9SAndroid Build Coastguard Worker # there should still be a Bailout after disable_grad call 3207*da0073e9SAndroid Build Coastguard Worker FileCheck().check("disable_grad").check("BailOut[").check("BailoutTemplate").run(g) 3208*da0073e9SAndroid Build Coastguard Worker 3209*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Torchdynamo cannot correctly handle profiler.profile calls") 3210*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") 3211*da0073e9SAndroid Build Coastguard Worker def test_profiling_merge(self): 3212*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3213*da0073e9SAndroid Build Coastguard Worker def test_not_const(x): 3214*da0073e9SAndroid Build Coastguard Worker if x.size(0) == 1: 3215*da0073e9SAndroid Build Coastguard Worker return 1 3216*da0073e9SAndroid Build Coastguard Worker else: 3217*da0073e9SAndroid Build Coastguard Worker return 2 3218*da0073e9SAndroid Build Coastguard Worker 3219*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 3220*da0073e9SAndroid Build Coastguard Worker with num_profiled_runs(2): 3221*da0073e9SAndroid Build Coastguard Worker test_not_const(torch.rand([1, 2])) 3222*da0073e9SAndroid Build Coastguard Worker test_not_const(torch.rand([2, 2])) 3223*da0073e9SAndroid Build Coastguard Worker 
3224*da0073e9SAndroid Build Coastguard Worker graph_str = torch.jit.last_executed_optimized_graph() 3225*da0073e9SAndroid Build Coastguard Worker FileCheck().check("profiled_type=Float(*, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str) 3226*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("profiled_type=Float(1, 2, strides=[2, 1], requires_grad=0, device=cpu").run(graph_str) 3227*da0073e9SAndroid Build Coastguard Worker 3228*da0073e9SAndroid Build Coastguard Worker 3229*da0073e9SAndroid Build Coastguard Worker def test_nested_bailouts(self): 3230*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3231*da0073e9SAndroid Build Coastguard Worker def fct_loop(x): 3232*da0073e9SAndroid Build Coastguard Worker for i in range(3): 3233*da0073e9SAndroid Build Coastguard Worker x = torch.cat((x, x), 0) 3234*da0073e9SAndroid Build Coastguard Worker return x 3235*da0073e9SAndroid Build Coastguard Worker 3236*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 3, 4, dtype=torch.float32) 3237*da0073e9SAndroid Build Coastguard Worker out = fct_loop(x) 3238*da0073e9SAndroid Build Coastguard Worker jit_trace = torch.jit.trace(fct_loop, x) 3239*da0073e9SAndroid Build Coastguard Worker out_trace = jit_trace(x) 3240*da0073e9SAndroid Build Coastguard Worker 3241*da0073e9SAndroid Build Coastguard Worker def test_no_self_arg_ignore_function(self): 3242*da0073e9SAndroid Build Coastguard Worker class MyModule(nn.Module): 3243*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore # noqa: B902 3244*da0073e9SAndroid Build Coastguard Worker def call_np(): # noqa: B902 3245*da0073e9SAndroid Build Coastguard Worker # type: () -> int 3246*da0073e9SAndroid Build Coastguard Worker return np.random.choice(2, p=[.95, .05]) 3247*da0073e9SAndroid Build Coastguard Worker 3248*da0073e9SAndroid Build Coastguard Worker def forward(self): 3249*da0073e9SAndroid Build Coastguard Worker return self.call_np() 3250*da0073e9SAndroid Build Coastguard Worker 
3251*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "does not have a self argument"): 3252*da0073e9SAndroid Build Coastguard Worker torch.jit.script(MyModule()) 3253*da0073e9SAndroid Build Coastguard Worker 3254*da0073e9SAndroid Build Coastguard Worker def test_loop_liveness(self): 3255*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 3256*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3257*da0073e9SAndroid Build Coastguard Worker def f(i): 3258*da0073e9SAndroid Build Coastguard Worker # type: (int) -> Tensor 3259*da0073e9SAndroid Build Coastguard Worker l = [] 3260*da0073e9SAndroid Build Coastguard Worker for n in [2, 1]: 3261*da0073e9SAndroid Build Coastguard Worker l.append(torch.zeros(n, i)) 3262*da0073e9SAndroid Build Coastguard Worker 3263*da0073e9SAndroid Build Coastguard Worker return l[0] 3264*da0073e9SAndroid Build Coastguard Worker 3265*da0073e9SAndroid Build Coastguard Worker f(2) 3266*da0073e9SAndroid Build Coastguard Worker f(1) 3267*da0073e9SAndroid Build Coastguard Worker 3268*da0073e9SAndroid Build Coastguard Worker def test_bailout_loop_carried_deps_name_clash(self): 3269*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 3270*da0073e9SAndroid Build Coastguard Worker NUM_ITERATIONS = 10 3271*da0073e9SAndroid Build Coastguard Worker 3272*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3273*da0073e9SAndroid Build Coastguard Worker def fct_loop(z, size): 3274*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> Tuple[Tensor, List[int]] 3275*da0073e9SAndroid Build Coastguard Worker counters = torch.jit.annotate(List[int], []) 3276*da0073e9SAndroid Build Coastguard Worker j = 0 3277*da0073e9SAndroid Build Coastguard Worker y = torch.ones(2) 3278*da0073e9SAndroid Build Coastguard Worker for i in range(size): 3279*da0073e9SAndroid Build Coastguard Worker counters.append(i + j) 3280*da0073e9SAndroid Build 
Coastguard Worker y = torch.cat((y, torch.ones(z)), 0) 3281*da0073e9SAndroid Build Coastguard Worker j = j + 1 3282*da0073e9SAndroid Build Coastguard Worker return y, counters 3283*da0073e9SAndroid Build Coastguard Worker 3284*da0073e9SAndroid Build Coastguard Worker inputs = [1, 2, 3, 4] 3285*da0073e9SAndroid Build Coastguard Worker expected = [x * 2 for x in range(NUM_ITERATIONS)] 3286*da0073e9SAndroid Build Coastguard Worker for inp in inputs: 3287*da0073e9SAndroid Build Coastguard Worker results = fct_loop(inp, NUM_ITERATIONS) 3288*da0073e9SAndroid Build Coastguard Worker self.assertEqual(results[1], expected) 3289*da0073e9SAndroid Build Coastguard Worker 3290*da0073e9SAndroid Build Coastguard Worker def test_bailout_loop_counter_transition(self): 3291*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 3292*da0073e9SAndroid Build Coastguard Worker NUM_ITERATIONS = 10 3293*da0073e9SAndroid Build Coastguard Worker 3294*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3295*da0073e9SAndroid Build Coastguard Worker def fct_loop(z, size): 3296*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> Tuple[Tensor, List[int]] 3297*da0073e9SAndroid Build Coastguard Worker counters = torch.jit.annotate(List[int], []) 3298*da0073e9SAndroid Build Coastguard Worker y = torch.ones(2) 3299*da0073e9SAndroid Build Coastguard Worker for i in range(size): 3300*da0073e9SAndroid Build Coastguard Worker counters.append(i) 3301*da0073e9SAndroid Build Coastguard Worker y = torch.cat((y, torch.ones(z)), 0) 3302*da0073e9SAndroid Build Coastguard Worker return y, counters 3303*da0073e9SAndroid Build Coastguard Worker 3304*da0073e9SAndroid Build Coastguard Worker inputs = [1, 2, 3, 4] 3305*da0073e9SAndroid Build Coastguard Worker expected = list(range(NUM_ITERATIONS)) 3306*da0073e9SAndroid Build Coastguard Worker for inp in inputs: 3307*da0073e9SAndroid Build Coastguard Worker results = fct_loop(inp, NUM_ITERATIONS) 
3308*da0073e9SAndroid Build Coastguard Worker self.assertEqual(results[1], expected) 3309*da0073e9SAndroid Build Coastguard Worker 3310*da0073e9SAndroid Build Coastguard Worker def test_ignored_method_binding(self): 3311*da0073e9SAndroid Build Coastguard Worker class Bar(torch.nn.Module): 3312*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3313*da0073e9SAndroid Build Coastguard Worker super().__init__() 3314*da0073e9SAndroid Build Coastguard Worker self.x : int = 0 3315*da0073e9SAndroid Build Coastguard Worker 3316*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 3317*da0073e9SAndroid Build Coastguard Worker def setx(self, x : int): 3318*da0073e9SAndroid Build Coastguard Worker self.x = x 3319*da0073e9SAndroid Build Coastguard Worker 3320*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 3321*da0073e9SAndroid Build Coastguard Worker def getx(self): 3322*da0073e9SAndroid Build Coastguard Worker return self.x 3323*da0073e9SAndroid Build Coastguard Worker 3324*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 3325*da0073e9SAndroid Build Coastguard Worker def ignored_getx(self): 3326*da0073e9SAndroid Build Coastguard Worker return self.x 3327*da0073e9SAndroid Build Coastguard Worker 3328*da0073e9SAndroid Build Coastguard Worker b = Bar() 3329*da0073e9SAndroid Build Coastguard Worker b.setx(123) 3330*da0073e9SAndroid Build Coastguard Worker sb = torch.jit.script(b) 3331*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sb.getx(), 123) 3332*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sb.ignored_getx(), 123) 3333*da0073e9SAndroid Build Coastguard Worker 3334*da0073e9SAndroid Build Coastguard Worker sb.setx(456) 3335*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sb.getx(), 456) 3336*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sb.ignored_getx(), 456) 3337*da0073e9SAndroid Build Coastguard Worker 3338*da0073e9SAndroid Build Coastguard Worker def 
test_set_attribute_through_optional(self): 3339*da0073e9SAndroid Build Coastguard Worker class A(torch.nn.Module): 3340*da0073e9SAndroid Build Coastguard Worker __annotations__ = {"x": Optional[torch.Tensor]} 3341*da0073e9SAndroid Build Coastguard Worker 3342*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3343*da0073e9SAndroid Build Coastguard Worker super().__init__() 3344*da0073e9SAndroid Build Coastguard Worker self.x = None 3345*da0073e9SAndroid Build Coastguard Worker 3346*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 3347*da0073e9SAndroid Build Coastguard Worker def foo(self): 3348*da0073e9SAndroid Build Coastguard Worker if self.x is None: 3349*da0073e9SAndroid Build Coastguard Worker self.x = torch.tensor([3]) 3350*da0073e9SAndroid Build Coastguard Worker return self.x 3351*da0073e9SAndroid Build Coastguard Worker 3352*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3353*da0073e9SAndroid Build Coastguard Worker a = self.foo() 3354*da0073e9SAndroid Build Coastguard Worker return x + 1 3355*da0073e9SAndroid Build Coastguard Worker 3356*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(A()) 3357*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.x, None) 3358*da0073e9SAndroid Build Coastguard Worker m(torch.rand(1)) 3359*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.x, torch.tensor([3])) 3360*da0073e9SAndroid Build Coastguard Worker 3361*da0073e9SAndroid Build Coastguard Worker def test_mutate_constant(self): 3362*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 3363*da0073e9SAndroid Build Coastguard Worker __constants__ = ["foo"] 3364*da0073e9SAndroid Build Coastguard Worker 3365*da0073e9SAndroid Build Coastguard Worker def __init__(self, foo): 3366*da0073e9SAndroid Build Coastguard Worker super().__init__() 3367*da0073e9SAndroid Build Coastguard Worker self.foo = foo 3368*da0073e9SAndroid Build Coastguard Worker 3369*da0073e9SAndroid Build Coastguard 
Worker m = M(5) 3370*da0073e9SAndroid Build Coastguard Worker # m has a constant attribute, but we can't 3371*da0073e9SAndroid Build Coastguard Worker # assign to it 3372*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 3373*da0073e9SAndroid Build Coastguard Worker m.foo = 6 3374*da0073e9SAndroid Build Coastguard Worker 3375*da0073e9SAndroid Build Coastguard Worker def test_class_attribute(self): 3376*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 3377*da0073e9SAndroid Build Coastguard Worker FOO = 0 3378*da0073e9SAndroid Build Coastguard Worker 3379*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3380*da0073e9SAndroid Build Coastguard Worker super().__init__() 3381*da0073e9SAndroid Build Coastguard Worker self.foo = self.FOO 3382*da0073e9SAndroid Build Coastguard Worker m = M() 3383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.foo, M.FOO) 3384*da0073e9SAndroid Build Coastguard Worker 3385*da0073e9SAndroid Build Coastguard Worker def test_class_attribute_in_script(self): 3386*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 3387*da0073e9SAndroid Build Coastguard Worker FOO = 0 3388*da0073e9SAndroid Build Coastguard Worker 3389*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 3390*da0073e9SAndroid Build Coastguard Worker def forward(self): 3391*da0073e9SAndroid Build Coastguard Worker return self.FOO 3392*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 3393*da0073e9SAndroid Build Coastguard Worker M() 3394*da0073e9SAndroid Build Coastguard Worker 3395*da0073e9SAndroid Build Coastguard Worker def test_not_initialized_err(self): 3396*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 3397*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3398*da0073e9SAndroid Build Coastguard Worker self.foo = torch.rand(2, 3) 3399*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaises(RuntimeError): 3400*da0073e9SAndroid Build Coastguard Worker M() 3401*da0073e9SAndroid Build Coastguard Worker 3402*da0073e9SAndroid Build Coastguard Worker def test_attribute_in_init(self): 3403*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 3404*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3405*da0073e9SAndroid Build Coastguard Worker super().__init__() 3406*da0073e9SAndroid Build Coastguard Worker self.foo = torch.jit.Attribute(0.1, float) 3407*da0073e9SAndroid Build Coastguard Worker # we should be able to use self.foo as a float here 3408*da0073e9SAndroid Build Coastguard Worker assert 0.0 < self.foo 3409*da0073e9SAndroid Build Coastguard Worker M() 3410*da0073e9SAndroid Build Coastguard Worker 3411*da0073e9SAndroid Build Coastguard Worker def test_scriptable_fn_as_attr(self): 3412*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 3413*da0073e9SAndroid Build Coastguard Worker def __init__(self, fn): 3414*da0073e9SAndroid Build Coastguard Worker super().__init__() 3415*da0073e9SAndroid Build Coastguard Worker self.fn = fn 3416*da0073e9SAndroid Build Coastguard Worker 3417*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3418*da0073e9SAndroid Build Coastguard Worker return self.fn(x) 3419*da0073e9SAndroid Build Coastguard Worker 3420*da0073e9SAndroid Build Coastguard Worker m = M(torch.sigmoid) 3421*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(2, 3) 3422*da0073e9SAndroid Build Coastguard Worker self.checkModule(m, (inp, )) 3423*da0073e9SAndroid Build Coastguard Worker 3424*da0073e9SAndroid Build Coastguard Worker def test_sequence_parsing(self): 3425*da0073e9SAndroid Build Coastguard Worker tests = [ 3426*da0073e9SAndroid Build Coastguard Worker ("return [x, x,]", True), 3427*da0073e9SAndroid Build Coastguard Worker ("return [x x]", "expected ]"), 3428*da0073e9SAndroid Build Coastguard Worker ("return x, x,", True), 3429*da0073e9SAndroid Build 
Coastguard Worker ("return bar(x, x,)", True), 3430*da0073e9SAndroid Build Coastguard Worker ("return bar()", "Argument x not provided"), 3431*da0073e9SAndroid Build Coastguard Worker ("for a, b, in x, x,:\n pass", "List of iterables"), 3432*da0073e9SAndroid Build Coastguard Worker ("a, b, = x, x,\n return a + b", True) 3433*da0073e9SAndroid Build Coastguard Worker ] 3434*da0073e9SAndroid Build Coastguard Worker for exp, result in tests: 3435*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit() 3436*da0073e9SAndroid Build Coastguard Worker full = f""" 3437*da0073e9SAndroid Build Coastguard Workerdef bar(x, y): 3438*da0073e9SAndroid Build Coastguard Worker return x + y 3439*da0073e9SAndroid Build Coastguard Workerdef foo(x): 3440*da0073e9SAndroid Build Coastguard Worker {exp} 3441*da0073e9SAndroid Build Coastguard Worker """ 3442*da0073e9SAndroid Build Coastguard Worker if isinstance(result, str): 3443*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, result): 3444*da0073e9SAndroid Build Coastguard Worker cu.define(full) 3445*da0073e9SAndroid Build Coastguard Worker else: 3446*da0073e9SAndroid Build Coastguard Worker cu.define(full) 3447*da0073e9SAndroid Build Coastguard Worker 3448*da0073e9SAndroid Build Coastguard Worker def test_namedtuple_python(self): 3449*da0073e9SAndroid Build Coastguard Worker global MyTuple, MyMod # see [local resolution in python] 3450*da0073e9SAndroid Build Coastguard Worker MyTuple = namedtuple('MyTuple', ['a']) 3451*da0073e9SAndroid Build Coastguard Worker 3452*da0073e9SAndroid Build Coastguard Worker @torch.jit.unused 3453*da0073e9SAndroid Build Coastguard Worker def fn(): 3454*da0073e9SAndroid Build Coastguard Worker # type: () -> MyTuple 3455*da0073e9SAndroid Build Coastguard Worker return MyTuple(1) 3456*da0073e9SAndroid Build Coastguard Worker 3457*da0073e9SAndroid Build Coastguard Worker # Only check compilation 3458*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 
3459*da0073e9SAndroid Build Coastguard Worker def fn2(): 3460*da0073e9SAndroid Build Coastguard Worker # type: () -> MyTuple 3461*da0073e9SAndroid Build Coastguard Worker return fn() 3462*da0073e9SAndroid Build Coastguard Worker 3463*da0073e9SAndroid Build Coastguard Worker FileCheck().check("NamedTuple").run(fn2.graph) 3464*da0073e9SAndroid Build Coastguard Worker 3465*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 3466*da0073e9SAndroid Build Coastguard Worker @torch.jit.unused 3467*da0073e9SAndroid Build Coastguard Worker def fn(self): 3468*da0073e9SAndroid Build Coastguard Worker # type: () -> MyTuple 3469*da0073e9SAndroid Build Coastguard Worker return MyTuple(1) 3470*da0073e9SAndroid Build Coastguard Worker 3471*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3472*da0073e9SAndroid Build Coastguard Worker if 1 == 1: 3473*da0073e9SAndroid Build Coastguard Worker return MyTuple(torch.rand(2, 3)) 3474*da0073e9SAndroid Build Coastguard Worker else: 3475*da0073e9SAndroid Build Coastguard Worker return self.fn() 3476*da0073e9SAndroid Build Coastguard Worker 3477*da0073e9SAndroid Build Coastguard Worker # shouldn't throw a type error 3478*da0073e9SAndroid Build Coastguard Worker torch.jit.script(MyMod()) 3479*da0073e9SAndroid Build Coastguard Worker 3480*da0073e9SAndroid Build Coastguard Worker def test_unused_decorator(self): 3481*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 3482*da0073e9SAndroid Build Coastguard Worker @torch.jit.unused 3483*da0073e9SAndroid Build Coastguard Worker @torch.no_grad() 3484*da0073e9SAndroid Build Coastguard Worker def fn(self, x): 3485*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> int 3486*da0073e9SAndroid Build Coastguard Worker return next(x) # invalid, but should be ignored 3487*da0073e9SAndroid Build Coastguard Worker 3488*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3489*da0073e9SAndroid Build Coastguard Worker return self.fn(x) 
3490*da0073e9SAndroid Build Coastguard Worker 3491*da0073e9SAndroid Build Coastguard Worker torch.jit.script(MyMod()) 3492*da0073e9SAndroid Build Coastguard Worker 3493*da0073e9SAndroid Build Coastguard Worker @_inline_everything 3494*da0073e9SAndroid Build Coastguard Worker def test_lazy_script(self): 3495*da0073e9SAndroid Build Coastguard Worker def untraceable(x): 3496*da0073e9SAndroid Build Coastguard Worker if x.ndim > 2: 3497*da0073e9SAndroid Build Coastguard Worker print("hello") 3498*da0073e9SAndroid Build Coastguard Worker else: 3499*da0073e9SAndroid Build Coastguard Worker print("goodbye") 3500*da0073e9SAndroid Build Coastguard Worker return x + 2 3501*da0073e9SAndroid Build Coastguard Worker 3502*da0073e9SAndroid Build Coastguard Worker # Non-working example 3503*da0073e9SAndroid Build Coastguard Worker def fn(x): 3504*da0073e9SAndroid Build Coastguard Worker return untraceable(x) 3505*da0073e9SAndroid Build Coastguard Worker 3506*da0073e9SAndroid Build Coastguard Worker with self.capture_stdout(): 3507*da0073e9SAndroid Build Coastguard Worker traced_bad = torch.jit.trace(fn, [torch.ones(2, 2)]) 3508*da0073e9SAndroid Build Coastguard Worker 3509*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("goodbye").check_not("hello").run(traced_bad.graph) 3510*da0073e9SAndroid Build Coastguard Worker 3511*da0073e9SAndroid Build Coastguard Worker # Working example 3512*da0073e9SAndroid Build Coastguard Worker untraceable = torch.jit.script_if_tracing(untraceable) 3513*da0073e9SAndroid Build Coastguard Worker 3514*da0073e9SAndroid Build Coastguard Worker def fn2(x): 3515*da0073e9SAndroid Build Coastguard Worker return untraceable(x) 3516*da0073e9SAndroid Build Coastguard Worker 3517*da0073e9SAndroid Build Coastguard Worker with self.capture_stdout(): 3518*da0073e9SAndroid Build Coastguard Worker traced = torch.jit.trace(fn, [torch.ones(2, 2)]) 3519*da0073e9SAndroid Build Coastguard Worker 3520*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check("goodbye").run(traced.graph) 3521*da0073e9SAndroid Build Coastguard Worker 3522*da0073e9SAndroid Build Coastguard Worker def foo(x: int): 3523*da0073e9SAndroid Build Coastguard Worker return x + 1 3524*da0073e9SAndroid Build Coastguard Worker 3525*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_if_tracing 3526*da0073e9SAndroid Build Coastguard Worker def fee(x: int = 2): 3527*da0073e9SAndroid Build Coastguard Worker return foo(1) + x 3528*da0073e9SAndroid Build Coastguard Worker 3529*da0073e9SAndroid Build Coastguard Worker # test directly compiling function 3530*da0073e9SAndroid Build Coastguard Worker fee_compiled = torch.jit.script(fee) 3531*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fee_compiled(), fee()) 3532*da0073e9SAndroid Build Coastguard Worker 3533*da0073e9SAndroid Build Coastguard Worker # test compiling it within another function 3534*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3535*da0073e9SAndroid Build Coastguard Worker def hum(): 3536*da0073e9SAndroid Build Coastguard Worker return fee(x=3) 3537*da0073e9SAndroid Build Coastguard Worker 3538*da0073e9SAndroid Build Coastguard Worker self.assertEqual(hum(), 5) 3539*da0073e9SAndroid Build Coastguard Worker 3540*da0073e9SAndroid Build Coastguard Worker def test_big_int_literals(self): 3541*da0073e9SAndroid Build Coastguard Worker def ok(): 3542*da0073e9SAndroid Build Coastguard Worker # signed 64 bit max 3543*da0073e9SAndroid Build Coastguard Worker a = 9223372036854775807 3544*da0073e9SAndroid Build Coastguard Worker return a 3545*da0073e9SAndroid Build Coastguard Worker 3546*da0073e9SAndroid Build Coastguard Worker def toobig(): 3547*da0073e9SAndroid Build Coastguard Worker a = 9223372036854775808 3548*da0073e9SAndroid Build Coastguard Worker return a 3549*da0073e9SAndroid Build Coastguard Worker 3550*da0073e9SAndroid Build Coastguard Worker def waytoobig(): 3551*da0073e9SAndroid Build Coastguard Worker a = 99999999999999999999 
3552*da0073e9SAndroid Build Coastguard Worker return a 3553*da0073e9SAndroid Build Coastguard Worker 3554*da0073e9SAndroid Build Coastguard Worker self.checkScript(ok, []) 3555*da0073e9SAndroid Build Coastguard Worker 3556*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "out of range"): 3557*da0073e9SAndroid Build Coastguard Worker torch.jit.script(toobig) 3558*da0073e9SAndroid Build Coastguard Worker 3559*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "out of range"): 3560*da0073e9SAndroid Build Coastguard Worker torch.jit.script(waytoobig) 3561*da0073e9SAndroid Build Coastguard Worker 3562*da0073e9SAndroid Build Coastguard Worker def test_hex_literals(self): 3563*da0073e9SAndroid Build Coastguard Worker def test1(): 3564*da0073e9SAndroid Build Coastguard Worker return 0xaaaaaa 3565*da0073e9SAndroid Build Coastguard Worker 3566*da0073e9SAndroid Build Coastguard Worker def test2(): 3567*da0073e9SAndroid Build Coastguard Worker return 0xaaaaaa 3568*da0073e9SAndroid Build Coastguard Worker 3569*da0073e9SAndroid Build Coastguard Worker def test3(): 3570*da0073e9SAndroid Build Coastguard Worker return -0xaaaaaa 3571*da0073e9SAndroid Build Coastguard Worker 3572*da0073e9SAndroid Build Coastguard Worker self.checkScript(test1, []) 3573*da0073e9SAndroid Build Coastguard Worker self.checkScript(test2, []) 3574*da0073e9SAndroid Build Coastguard Worker self.checkScript(test3, []) 3575*da0073e9SAndroid Build Coastguard Worker 3576*da0073e9SAndroid Build Coastguard Worker def ok(): 3577*da0073e9SAndroid Build Coastguard Worker a = 0x7FFFFFFFFFFFFFFF 3578*da0073e9SAndroid Build Coastguard Worker return a 3579*da0073e9SAndroid Build Coastguard Worker 3580*da0073e9SAndroid Build Coastguard Worker def toobig(): 3581*da0073e9SAndroid Build Coastguard Worker a = 0xFFFFFFFFFFFFFFFF 3582*da0073e9SAndroid Build Coastguard Worker return a 3583*da0073e9SAndroid Build Coastguard Worker 3584*da0073e9SAndroid Build 
Coastguard Worker def waytoobig(): 3585*da0073e9SAndroid Build Coastguard Worker a = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF 3586*da0073e9SAndroid Build Coastguard Worker return a 3587*da0073e9SAndroid Build Coastguard Worker 3588*da0073e9SAndroid Build Coastguard Worker self.checkScript(ok, []) 3589*da0073e9SAndroid Build Coastguard Worker 3590*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "out of range"): 3591*da0073e9SAndroid Build Coastguard Worker torch.jit.script(toobig) 3592*da0073e9SAndroid Build Coastguard Worker 3593*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "out of range"): 3594*da0073e9SAndroid Build Coastguard Worker torch.jit.script(waytoobig) 3595*da0073e9SAndroid Build Coastguard Worker 3596*da0073e9SAndroid Build Coastguard Worker def test_big_float_literals(self): 3597*da0073e9SAndroid Build Coastguard Worker def ok(): 3598*da0073e9SAndroid Build Coastguard Worker # Python interprets this as inf 3599*da0073e9SAndroid Build Coastguard Worker a = 1.2E400 3600*da0073e9SAndroid Build Coastguard Worker return a 3601*da0073e9SAndroid Build Coastguard Worker 3602*da0073e9SAndroid Build Coastguard Worker def check(fn): 3603*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fn() == ok()) 3604*da0073e9SAndroid Build Coastguard Worker 3605*da0073e9SAndroid Build Coastguard Worker # checkScript doesn't work since assertEqual doesn't consider 3606*da0073e9SAndroid Build Coastguard Worker # `inf` == `inf` 3607*da0073e9SAndroid Build Coastguard Worker check(torch.jit.script(ok)) 3608*da0073e9SAndroid Build Coastguard Worker 3609*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit() 3610*da0073e9SAndroid Build Coastguard Worker cu.define(dedent(inspect.getsource(ok))) 3611*da0073e9SAndroid Build Coastguard Worker check(cu.ok) 3612*da0073e9SAndroid Build Coastguard Worker 3613*da0073e9SAndroid Build Coastguard Worker def _test_device_type(self, dest): 
3614*da0073e9SAndroid Build Coastguard Worker def fn(x): 3615*da0073e9SAndroid Build Coastguard Worker # type: (Device) -> Tuple[str, Optional[int]] 3616*da0073e9SAndroid Build Coastguard Worker return x.type, x.index 3617*da0073e9SAndroid Build Coastguard Worker 3618*da0073e9SAndroid Build Coastguard Worker device = torch.ones(2).to(dest).device 3619*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, [device]) 3620*da0073e9SAndroid Build Coastguard Worker 3621*da0073e9SAndroid Build Coastguard Worker def test_device_type(self): 3622*da0073e9SAndroid Build Coastguard Worker self._test_device_type('cpu') 3623*da0073e9SAndroid Build Coastguard Worker 3624*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "Requires CUDA") 3625*da0073e9SAndroid Build Coastguard Worker def test_device_type_cuda(self): 3626*da0073e9SAndroid Build Coastguard Worker self._test_device_type('cuda') 3627*da0073e9SAndroid Build Coastguard Worker 3628*da0073e9SAndroid Build Coastguard Worker def test_string_device_implicit_conversion(self): 3629*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3630*da0073e9SAndroid Build Coastguard Worker def fn(x: torch.device): 3631*da0073e9SAndroid Build Coastguard Worker return x 3632*da0073e9SAndroid Build Coastguard Worker 3633*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn("cpu"), torch.device("cpu")) 3634*da0073e9SAndroid Build Coastguard Worker 3635*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected one of"): 3636*da0073e9SAndroid Build Coastguard Worker fn("invalid_device") 3637*da0073e9SAndroid Build Coastguard Worker 3638*da0073e9SAndroid Build Coastguard Worker def test_eval_python(self): 3639*da0073e9SAndroid Build Coastguard Worker def _test(m): 3640*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m(torch.ones(2, 2))) 3641*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m.training) 3642*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(m._c.getattr('training')) 3643*da0073e9SAndroid Build Coastguard Worker 3644*da0073e9SAndroid Build Coastguard Worker m.eval() 3645*da0073e9SAndroid Build Coastguard Worker 3646*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m.training) 3647*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m._c.getattr('training')) 3648*da0073e9SAndroid Build Coastguard Worker self.assertFalse(m(torch.ones(2, 2))) 3649*da0073e9SAndroid Build Coastguard Worker 3650*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 3651*da0073e9SAndroid Build Coastguard Worker torch.jit.save(m, buffer) 3652*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 3653*da0073e9SAndroid Build Coastguard Worker 3654*da0073e9SAndroid Build Coastguard Worker loaded = torch.jit.load(buffer) 3655*da0073e9SAndroid Build Coastguard Worker 3656*da0073e9SAndroid Build Coastguard Worker self.assertFalse(loaded.training) 3657*da0073e9SAndroid Build Coastguard Worker self.assertFalse(loaded._c.getattr('training')) 3658*da0073e9SAndroid Build Coastguard Worker 3659*da0073e9SAndroid Build Coastguard Worker class M(nn.Module): 3660*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3661*da0073e9SAndroid Build Coastguard Worker return self.training 3662*da0073e9SAndroid Build Coastguard Worker 3663*da0073e9SAndroid Build Coastguard Worker class OldM(torch.jit.ScriptModule): 3664*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 3665*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3666*da0073e9SAndroid Build Coastguard Worker return self.training 3667*da0073e9SAndroid Build Coastguard Worker 3668*da0073e9SAndroid Build Coastguard Worker _test(torch.jit.script(M())) 3669*da0073e9SAndroid Build Coastguard Worker _test(OldM()) 3670*da0073e9SAndroid Build Coastguard Worker 3671*da0073e9SAndroid Build Coastguard Worker def test_inherit_method(self): 3672*da0073e9SAndroid Build Coastguard Worker class A(torch.jit.ScriptModule): 
3673*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 3674*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3675*da0073e9SAndroid Build Coastguard Worker return x + self.bar(x) 3676*da0073e9SAndroid Build Coastguard Worker 3677*da0073e9SAndroid Build Coastguard Worker class B(A): 3678*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 3679*da0073e9SAndroid Build Coastguard Worker def bar(self, x): 3680*da0073e9SAndroid Build Coastguard Worker return x * x 3681*da0073e9SAndroid Build Coastguard Worker 3682*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'attribute'): 3683*da0073e9SAndroid Build Coastguard Worker A() # cannot use because bar is not defined 3684*da0073e9SAndroid Build Coastguard Worker 3685*da0073e9SAndroid Build Coastguard Worker v = torch.rand(3, 4) 3686*da0073e9SAndroid Build Coastguard Worker b = B() 3687*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b(v), v + v * v) 3688*da0073e9SAndroid Build Coastguard Worker 3689*da0073e9SAndroid Build Coastguard Worker class C(torch.jit.ScriptModule): 3690*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 3691*da0073e9SAndroid Build Coastguard Worker def bar(self, x): 3692*da0073e9SAndroid Build Coastguard Worker return x 3693*da0073e9SAndroid Build Coastguard Worker 3694*da0073e9SAndroid Build Coastguard Worker class D(C, B): 3695*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3696*da0073e9SAndroid Build Coastguard Worker super().__init__() 3697*da0073e9SAndroid Build Coastguard Worker 3698*da0073e9SAndroid Build Coastguard Worker self.assertEqual(D()(v), v + v) 3699*da0073e9SAndroid Build Coastguard Worker 3700*da0073e9SAndroid Build Coastguard Worker def test_tensor_subclasses(self): 3701*da0073e9SAndroid Build Coastguard Worker def check_subclass(x, tensor): 3702*da0073e9SAndroid Build Coastguard Worker template = dedent(""" 3703*da0073e9SAndroid Build Coastguard Worker def 
func(input: {}) -> {}: 3704*da0073e9SAndroid Build Coastguard Worker return torch.zeros((input.shape[0], 1), dtype=input.dtype) 3705*da0073e9SAndroid Build Coastguard Worker """) 3706*da0073e9SAndroid Build Coastguard Worker 3707*da0073e9SAndroid Build Coastguard Worker self._check_code(template.format(x, x), "func", [tensor]) 3708*da0073e9SAndroid Build Coastguard Worker 3709*da0073e9SAndroid Build Coastguard Worker check_subclass("torch.LongTensor", torch.LongTensor([[1, 2], [3, 4]])) 3710*da0073e9SAndroid Build Coastguard Worker check_subclass("torch.DoubleTensor", torch.DoubleTensor([[1.2, 2.3], [3.4, 4.5]])) 3711*da0073e9SAndroid Build Coastguard Worker check_subclass("torch.IntTensor", torch.IntTensor([[1, 2], [3, 4]])) 3712*da0073e9SAndroid Build Coastguard Worker check_subclass("torch.BoolTensor", torch.BoolTensor([[False, True], [True, False]])) 3713*da0073e9SAndroid Build Coastguard Worker 3714*da0073e9SAndroid Build Coastguard Worker def check_subclass_warn(input: torch.LongTensor) -> torch.LongTensor: 3715*da0073e9SAndroid Build Coastguard Worker return torch.zeros((input.shape[0], 1), dtype=input.dtype) 3716*da0073e9SAndroid Build Coastguard Worker 3717*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as warns: 3718*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(check_subclass_warn) 3719*da0073e9SAndroid Build Coastguard Worker FileCheck().check("TorchScript will treat type annotations of Tensor").run(str(warns[0])) 3720*da0073e9SAndroid Build Coastguard Worker 3721*da0073e9SAndroid Build Coastguard Worker def test_first_class_module(self): 3722*da0073e9SAndroid Build Coastguard Worker class Foo(torch.jit.ScriptModule): 3723*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3724*da0073e9SAndroid Build Coastguard Worker super().__init__() 3725*da0073e9SAndroid Build Coastguard Worker self.foo = nn.Parameter(torch.rand(3, 4)) 3726*da0073e9SAndroid Build Coastguard Worker 
3727*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 3728*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 3729*da0073e9SAndroid Build Coastguard Worker self.foo = input 3730*da0073e9SAndroid Build Coastguard Worker return self.foo 3731*da0073e9SAndroid Build Coastguard Worker foo = Foo() 3732*da0073e9SAndroid Build Coastguard Worker input = torch.rand(3, 4) 3733*da0073e9SAndroid Build Coastguard Worker foo.forward(input) 3734*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, foo.foo) 3735*da0073e9SAndroid Build Coastguard Worker 3736*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 3737*da0073e9SAndroid Build Coastguard Worker def test_first_class_calls(self): 3738*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3739*da0073e9SAndroid Build Coastguard Worker class Foo: 3740*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 3741*da0073e9SAndroid Build Coastguard Worker self.bar = x 3742*da0073e9SAndroid Build Coastguard Worker 3743*da0073e9SAndroid Build Coastguard Worker def stuff(self, x): 3744*da0073e9SAndroid Build Coastguard Worker return self.bar + x 3745*da0073e9SAndroid Build Coastguard Worker 3746*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3747*da0073e9SAndroid Build Coastguard Worker def foo(x): 3748*da0073e9SAndroid Build Coastguard Worker return x * x + Foo(x).stuff(2 * x) 3749*da0073e9SAndroid Build Coastguard Worker 3750*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3751*da0073e9SAndroid Build Coastguard Worker def bar(x): 3752*da0073e9SAndroid Build Coastguard Worker return foo(x) * foo(x) 3753*da0073e9SAndroid Build Coastguard Worker 3754*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 3755*da0073e9SAndroid Build Coastguard Worker self.assertEqual(bar(x), (x * x + 3 * x) * (x * x + 3 * x)) 3756*da0073e9SAndroid Build Coastguard Worker 3757*da0073e9SAndroid Build Coastguard Worker def 
test_static_methods(self): 3758*da0073e9SAndroid Build Coastguard Worker class M(nn.Module): 3759*da0073e9SAndroid Build Coastguard Worker @staticmethod 3760*da0073e9SAndroid Build Coastguard Worker def my_method(x): 3761*da0073e9SAndroid Build Coastguard Worker return x + 100 3762*da0073e9SAndroid Build Coastguard Worker 3763*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3764*da0073e9SAndroid Build Coastguard Worker return x + M.my_method(x) 3765*da0073e9SAndroid Build Coastguard Worker 3766*da0073e9SAndroid Build Coastguard Worker class N(nn.Module): 3767*da0073e9SAndroid Build Coastguard Worker @staticmethod 3768*da0073e9SAndroid Build Coastguard Worker def my_method(x): 3769*da0073e9SAndroid Build Coastguard Worker return x * 100 3770*da0073e9SAndroid Build Coastguard Worker 3771*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 3772*da0073e9SAndroid Build Coastguard Worker return x - M.my_method(x) + N.my_method(x) 3773*da0073e9SAndroid Build Coastguard Worker 3774*da0073e9SAndroid Build Coastguard Worker self.checkModule(M(), (torch.ones(2, 2),)) 3775*da0073e9SAndroid Build Coastguard Worker 3776*da0073e9SAndroid Build Coastguard Worker self.checkModule(N(), (torch.ones(2, 2),)) 3777*da0073e9SAndroid Build Coastguard Worker 3778*da0073e9SAndroid Build Coastguard Worker def test_invalid_prefix_annotation(self): 3779*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"): 3780*da0073e9SAndroid Build Coastguard Worker with self.capture_stdout() as captured: 3781*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3782*da0073e9SAndroid Build Coastguard Worker def invalid_prefix_annotation1(a): 3783*da0073e9SAndroid Build Coastguard Worker #type: (Int) -> Int # noqa: E265 3784*da0073e9SAndroid Build Coastguard Worker return a + 2 3785*da0073e9SAndroid Build Coastguard Worker 3786*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 
"annotation prefix in line"): 3787*da0073e9SAndroid Build Coastguard Worker with self.capture_stdout() as captured: 3788*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3789*da0073e9SAndroid Build Coastguard Worker def invalid_prefix_annotation2(a): 3790*da0073e9SAndroid Build Coastguard Worker #type : (Int) -> Int # noqa: E265 3791*da0073e9SAndroid Build Coastguard Worker return a + 2 3792*da0073e9SAndroid Build Coastguard Worker 3793*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"): 3794*da0073e9SAndroid Build Coastguard Worker with self.capture_stdout() as captured: 3795*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3796*da0073e9SAndroid Build Coastguard Worker def invalid_prefix_annotation3(a): 3797*da0073e9SAndroid Build Coastguard Worker # type: (Int) -> Int 3798*da0073e9SAndroid Build Coastguard Worker return a + 2 3799*da0073e9SAndroid Build Coastguard Worker 3800*da0073e9SAndroid Build Coastguard Worker def test_builtin_function_attributes(self): 3801*da0073e9SAndroid Build Coastguard Worker class Add(nn.Module): 3802*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 3803*da0073e9SAndroid Build Coastguard Worker super().__init__() 3804*da0073e9SAndroid Build Coastguard Worker self.add = torch.add 3805*da0073e9SAndroid Build Coastguard Worker 3806*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 3807*da0073e9SAndroid Build Coastguard Worker return self.add(input, input) 3808*da0073e9SAndroid Build Coastguard Worker 3809*da0073e9SAndroid Build Coastguard Worker self.checkModule(Add(), [torch.randn(2, 2)]) 3810*da0073e9SAndroid Build Coastguard Worker 3811*da0073e9SAndroid Build Coastguard Worker def test_pybind_type_comparisons(self): 3812*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3813*da0073e9SAndroid Build Coastguard Worker def f(): 3814*da0073e9SAndroid Build Coastguard Worker return None 3815*da0073e9SAndroid Build 
Coastguard Worker 3816*da0073e9SAndroid Build Coastguard Worker node = list(f.graph.nodes())[0] 3817*da0073e9SAndroid Build Coastguard Worker t = node.outputsAt(0).type() 3818*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(t) 3819*da0073e9SAndroid Build Coastguard Worker 3820*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, 'TODO: need to fix the test case') 3821*da0073e9SAndroid Build Coastguard Worker def test_unmatched_type_annotation(self): 3822*da0073e9SAndroid Build Coastguard Worker message1 = re.escape("Number of type annotations (2) did not match the number of function parameters (1):") 3823*da0073e9SAndroid Build Coastguard Worker message2 = 'def invalid2\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2' 3824*da0073e9SAndroid Build Coastguard Worker message3 = 'def invalid4\\(a\\):\n\\s*~+\\.*\\s+<--- HERE\n\\s+# type: \\(Int, Int\\) -> Int\n\\s+return a \\+ 2' 3825*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, message1): 3826*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3827*da0073e9SAndroid Build Coastguard Worker def invalid1(a): 3828*da0073e9SAndroid Build Coastguard Worker # type: (Int, Int) -> Int 3829*da0073e9SAndroid Build Coastguard Worker return a + 2 3830*da0073e9SAndroid Build Coastguard Worker 3831*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, message2): 3832*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 3833*da0073e9SAndroid Build Coastguard Worker def invalid2(a): 3834*da0073e9SAndroid Build Coastguard Worker # type: (Int, Int) -> Int 3835*da0073e9SAndroid Build Coastguard Worker return a + 2 3836*da0073e9SAndroid Build Coastguard Worker 3837*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, message1): 3838*da0073e9SAndroid Build Coastguard Worker def invalid3(a): 3839*da0073e9SAndroid Build Coastguard Worker # type: (Int, Int) -> Int 
3840*da0073e9SAndroid Build Coastguard Worker return a + 2 3841*da0073e9SAndroid Build Coastguard Worker torch.jit.script(invalid3) 3842*da0073e9SAndroid Build Coastguard Worker 3843*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, message3): 3844*da0073e9SAndroid Build Coastguard Worker def invalid4(a): 3845*da0073e9SAndroid Build Coastguard Worker # type: (Int, Int) -> Int 3846*da0073e9SAndroid Build Coastguard Worker return a + 2 3847*da0073e9SAndroid Build Coastguard Worker torch.jit.script(invalid4) 3848*da0073e9SAndroid Build Coastguard Worker 3849*da0073e9SAndroid Build Coastguard Worker def test_calls_in_type_annotations(self): 3850*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Type annotation should not contain calls"): 3851*da0073e9SAndroid Build Coastguard Worker def spooky(a): 3852*da0073e9SAndroid Build Coastguard Worker # type: print("Hello") -> Tensor # noqa: F723 3853*da0073e9SAndroid Build Coastguard Worker return a + 2 3854*da0073e9SAndroid Build Coastguard Worker print(torch.__file__) 3855*da0073e9SAndroid Build Coastguard Worker torch.jit.annotations.get_signature(spooky, None, 1, True) 3856*da0073e9SAndroid Build Coastguard Worker 3857*da0073e9SAndroid Build Coastguard Worker def test_is_optional(self): 3858*da0073e9SAndroid Build Coastguard Worker ann = Union[List[int], List[float]] 3859*da0073e9SAndroid Build Coastguard Worker torch._jit_internal.is_optional(ann) 3860*da0073e9SAndroid Build Coastguard Worker 3861*da0073e9SAndroid Build Coastguard Worker def test_interpreter_fuzz(self): 3862*da0073e9SAndroid Build Coastguard Worker import builtins 3863*da0073e9SAndroid Build Coastguard Worker # This test generates random tree-like programs to fuzz test 3864*da0073e9SAndroid Build Coastguard Worker # that the interpreter does not have a bug in its stack manipulation 3865*da0073e9SAndroid Build Coastguard Worker # code. 
An assert in that code ensures individual operators are 3866*da0073e9SAndroid Build Coastguard Worker # not reordered. 3867*da0073e9SAndroid Build Coastguard Worker templates = [ 3868*da0073e9SAndroid Build Coastguard Worker "torch.rand(3, 4)", 3869*da0073e9SAndroid Build Coastguard Worker "({} + {})", 3870*da0073e9SAndroid Build Coastguard Worker "-{}", 3871*da0073e9SAndroid Build Coastguard Worker "({} * {})", 3872*da0073e9SAndroid Build Coastguard Worker "torch.tanh({})", 3873*da0073e9SAndroid Build Coastguard Worker "VAR {}", 3874*da0073e9SAndroid Build Coastguard Worker ] 3875*da0073e9SAndroid Build Coastguard Worker 3876*da0073e9SAndroid Build Coastguard Worker def gen_code(): 3877*da0073e9SAndroid Build Coastguard Worker src_lines = ['def f():'] 3878*da0073e9SAndroid Build Coastguard Worker exprs = [] 3879*da0073e9SAndroid Build Coastguard Worker n_variables = 0 3880*da0073e9SAndroid Build Coastguard Worker 3881*da0073e9SAndroid Build Coastguard Worker def get_expr(idx): 3882*da0073e9SAndroid Build Coastguard Worker elem = exprs[idx] 3883*da0073e9SAndroid Build Coastguard Worker exprs[idx] = exprs[-1] 3884*da0073e9SAndroid Build Coastguard Worker exprs.pop() 3885*da0073e9SAndroid Build Coastguard Worker return elem 3886*da0073e9SAndroid Build Coastguard Worker 3887*da0073e9SAndroid Build Coastguard Worker def select_expr_or_var(): 3888*da0073e9SAndroid Build Coastguard Worker idx = random.randrange(0, len(exprs) + n_variables) 3889*da0073e9SAndroid Build Coastguard Worker if idx < len(exprs): 3890*da0073e9SAndroid Build Coastguard Worker return get_expr(idx) 3891*da0073e9SAndroid Build Coastguard Worker else: 3892*da0073e9SAndroid Build Coastguard Worker return f'v{idx - len(exprs)}' 3893*da0073e9SAndroid Build Coastguard Worker 3894*da0073e9SAndroid Build Coastguard Worker for i in range(50): 3895*da0073e9SAndroid Build Coastguard Worker n = None 3896*da0073e9SAndroid Build Coastguard Worker while n is None or n > len(exprs) + n_variables: 
3897*da0073e9SAndroid Build Coastguard Worker template = random.choice(templates) 3898*da0073e9SAndroid Build Coastguard Worker n = template.count('{}') 3899*da0073e9SAndroid Build Coastguard Worker 3900*da0073e9SAndroid Build Coastguard Worker if 'VAR' in template: 3901*da0073e9SAndroid Build Coastguard Worker src_lines.append(f' v{n_variables} = {select_expr_or_var()}') 3902*da0073e9SAndroid Build Coastguard Worker n_variables += 1 3903*da0073e9SAndroid Build Coastguard Worker else: 3904*da0073e9SAndroid Build Coastguard Worker exprs.append(template.format(*(select_expr_or_var() for _ in range(n)))) 3905*da0073e9SAndroid Build Coastguard Worker 3906*da0073e9SAndroid Build Coastguard Worker src_lines.append(' return ({})\n'.format(''.join(f'v{i},' for i in range(n_variables)))) 3907*da0073e9SAndroid Build Coastguard Worker return '\n'.join(src_lines) 3908*da0073e9SAndroid Build Coastguard Worker 3909*da0073e9SAndroid Build Coastguard Worker for i in range(100): 3910*da0073e9SAndroid Build Coastguard Worker g = {'torch': torch} 3911*da0073e9SAndroid Build Coastguard Worker code = gen_code() 3912*da0073e9SAndroid Build Coastguard Worker builtins.exec(code, g, None) 3913*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 3914*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 3915*da0073e9SAndroid Build Coastguard Worker o1 = g['f']() 3916*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 3917*da0073e9SAndroid Build Coastguard Worker o2 = cu.f() 3918*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o1, o2) 3919*da0073e9SAndroid Build Coastguard Worker 3920*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 3921*da0073e9SAndroid Build Coastguard Worker def test_cpp_module_iterator(self): 3922*da0073e9SAndroid Build Coastguard Worker a = nn.Module() 3923*da0073e9SAndroid Build Coastguard Worker a.name = 'a' 3924*da0073e9SAndroid Build Coastguard Worker 
a.p = nn.Parameter(torch.rand(3, 4)) 3925*da0073e9SAndroid Build Coastguard Worker a.foo = nn.Module() 3926*da0073e9SAndroid Build Coastguard Worker a.foo.name = 'foo' 3927*da0073e9SAndroid Build Coastguard Worker a.foo.b = nn.Buffer(torch.rand(1, 1)) 3928*da0073e9SAndroid Build Coastguard Worker a.foo.bar = nn.Module() 3929*da0073e9SAndroid Build Coastguard Worker a.foo.bar.name = 'bar' 3930*da0073e9SAndroid Build Coastguard Worker a.foo.bar.an_int = 4 3931*da0073e9SAndroid Build Coastguard Worker a.another = nn.Module() 3932*da0073e9SAndroid Build Coastguard Worker a.another.name = 'another' 3933*da0073e9SAndroid Build Coastguard Worker sa = torch.jit.script(a) 3934*da0073e9SAndroid Build Coastguard Worker result = torch._C._jit_debug_module_iterators(sa._c) 3935*da0073e9SAndroid Build Coastguard Worker 3936*da0073e9SAndroid Build Coastguard Worker def replace(e): 3937*da0073e9SAndroid Build Coastguard Worker if e is a.p: 3938*da0073e9SAndroid Build Coastguard Worker return 'P' 3939*da0073e9SAndroid Build Coastguard Worker elif e is a.foo.b: 3940*da0073e9SAndroid Build Coastguard Worker return 'B' 3941*da0073e9SAndroid Build Coastguard Worker elif isinstance(e, torch._C.ScriptModule): 3942*da0073e9SAndroid Build Coastguard Worker return e.getattr('name') 3943*da0073e9SAndroid Build Coastguard Worker 3944*da0073e9SAndroid Build Coastguard Worker return e 3945*da0073e9SAndroid Build Coastguard Worker for v in result.values(): 3946*da0073e9SAndroid Build Coastguard Worker for i in range(len(v)): 3947*da0073e9SAndroid Build Coastguard Worker if isinstance(v[i], tuple): 3948*da0073e9SAndroid Build Coastguard Worker n, v2 = v[i] 3949*da0073e9SAndroid Build Coastguard Worker v[i] = (n, replace(v2)) 3950*da0073e9SAndroid Build Coastguard Worker else: 3951*da0073e9SAndroid Build Coastguard Worker v[i] = replace(v[i]) 3952*da0073e9SAndroid Build Coastguard Worker # module type creation is not deterministic, so we have to sort 3953*da0073e9SAndroid Build Coastguard Worker # 
the result 3954*da0073e9SAndroid Build Coastguard Worker v.sort() 3955*da0073e9SAndroid Build Coastguard Worker expected = {'buffers': [], 3956*da0073e9SAndroid Build Coastguard Worker 'buffers_r': ['B'], 3957*da0073e9SAndroid Build Coastguard Worker 'children': ['another', 'foo'], 3958*da0073e9SAndroid Build Coastguard Worker 'modules': ['a', 'another', 'bar', 'foo'], 3959*da0073e9SAndroid Build Coastguard Worker 'named_attributes': [('_is_full_backward_hook', None), 3960*da0073e9SAndroid Build Coastguard Worker ('another', 'another'), 3961*da0073e9SAndroid Build Coastguard Worker ('foo', 'foo'), 3962*da0073e9SAndroid Build Coastguard Worker ('name', 'a'), 3963*da0073e9SAndroid Build Coastguard Worker ('p', 'P'), 3964*da0073e9SAndroid Build Coastguard Worker ('training', True)], 3965*da0073e9SAndroid Build Coastguard Worker 'named_attributes_r': [('_is_full_backward_hook', None), 3966*da0073e9SAndroid Build Coastguard Worker ('another', 'another'), 3967*da0073e9SAndroid Build Coastguard Worker ('another._is_full_backward_hook', None), 3968*da0073e9SAndroid Build Coastguard Worker ('another.name', 'another'), 3969*da0073e9SAndroid Build Coastguard Worker ('another.training', True), 3970*da0073e9SAndroid Build Coastguard Worker ('foo', 'foo'), 3971*da0073e9SAndroid Build Coastguard Worker ('foo._is_full_backward_hook', None), 3972*da0073e9SAndroid Build Coastguard Worker ('foo.b', 'B'), 3973*da0073e9SAndroid Build Coastguard Worker ('foo.bar', 'bar'), 3974*da0073e9SAndroid Build Coastguard Worker ('foo.bar._is_full_backward_hook', None), 3975*da0073e9SAndroid Build Coastguard Worker ('foo.bar.an_int', 4), 3976*da0073e9SAndroid Build Coastguard Worker ('foo.bar.name', 'bar'), 3977*da0073e9SAndroid Build Coastguard Worker ('foo.bar.training', True), 3978*da0073e9SAndroid Build Coastguard Worker ('foo.name', 'foo'), 3979*da0073e9SAndroid Build Coastguard Worker ('foo.training', True), 3980*da0073e9SAndroid Build Coastguard Worker ('name', 'a'), 3981*da0073e9SAndroid 
Build Coastguard Worker ('p', 'P'), 3982*da0073e9SAndroid Build Coastguard Worker ('training', True)], 3983*da0073e9SAndroid Build Coastguard Worker 'named_buffers': [], 3984*da0073e9SAndroid Build Coastguard Worker 'named_buffers_r': [('foo.b', 'B')], 3985*da0073e9SAndroid Build Coastguard Worker 'named_children': [('another', 'another'), ('foo', 'foo')], 3986*da0073e9SAndroid Build Coastguard Worker 'named_modules': [('', 'a'), 3987*da0073e9SAndroid Build Coastguard Worker ('another', 'another'), 3988*da0073e9SAndroid Build Coastguard Worker ('foo', 'foo'), 3989*da0073e9SAndroid Build Coastguard Worker ('foo.bar', 'bar')], 3990*da0073e9SAndroid Build Coastguard Worker 'named_parameters': [('p', 'P')], 3991*da0073e9SAndroid Build Coastguard Worker 'named_parameters_r': [('p', 'P')], 3992*da0073e9SAndroid Build Coastguard Worker 'parameters': ['P'], 3993*da0073e9SAndroid Build Coastguard Worker 'parameters_r': ['P']} 3994*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected, result) 3995*da0073e9SAndroid Build Coastguard Worker 3996*da0073e9SAndroid Build Coastguard Worker def test_parameter_order(self): 3997*da0073e9SAndroid Build Coastguard Worker m = nn.Module() 3998*da0073e9SAndroid Build Coastguard Worker for i, name in enumerate(string.ascii_letters): 3999*da0073e9SAndroid Build Coastguard Worker setattr(m, name, nn.Parameter(torch.tensor([float(i)]))) 4000*da0073e9SAndroid Build Coastguard Worker ms = torch.jit.script(m) 4001*da0073e9SAndroid Build Coastguard Worker print(torch.cat(list(m.parameters()))) 4002*da0073e9SAndroid Build Coastguard Worker print(torch.cat(list(ms.parameters()))) 4003*da0073e9SAndroid Build Coastguard Worker self.assertEqual(list(m.parameters()), list(ms.parameters())) 4004*da0073e9SAndroid Build Coastguard Worker 4005*da0073e9SAndroid Build Coastguard Worker def test_python_op_builtins(self): 4006*da0073e9SAndroid Build Coastguard Worker @torch.jit.unused 4007*da0073e9SAndroid Build Coastguard Worker def fn(x): 
4008*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> int 4009*da0073e9SAndroid Build Coastguard Worker return sum(x) 4010*da0073e9SAndroid Build Coastguard Worker 4011*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4012*da0073e9SAndroid Build Coastguard Worker def script_fn(x): 4013*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> int 4014*da0073e9SAndroid Build Coastguard Worker return fn(x) 4015*da0073e9SAndroid Build Coastguard Worker 4016*da0073e9SAndroid Build Coastguard Worker def test_submodule_twice(self): 4017*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4018*da0073e9SAndroid Build Coastguard Worker def foo(x): 4019*da0073e9SAndroid Build Coastguard Worker return x * x 4020*da0073e9SAndroid Build Coastguard Worker 4021*da0073e9SAndroid Build Coastguard Worker class What(torch.jit.ScriptModule): 4022*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 4023*da0073e9SAndroid Build Coastguard Worker super().__init__() 4024*da0073e9SAndroid Build Coastguard Worker self.foo = x 4025*da0073e9SAndroid Build Coastguard Worker a = What(foo) 4026*da0073e9SAndroid Build Coastguard Worker c = What(foo) 4027*da0073e9SAndroid Build Coastguard Worker 4028*da0073e9SAndroid Build Coastguard Worker def test_training_param(self): 4029*da0073e9SAndroid Build Coastguard Worker class What(torch.jit.ScriptModule): 4030*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4031*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 4032*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 4033*da0073e9SAndroid Build Coastguard Worker if self.training: 4034*da0073e9SAndroid Build Coastguard Worker r = x 4035*da0073e9SAndroid Build Coastguard Worker else: 4036*da0073e9SAndroid Build Coastguard Worker r = x + 4 4037*da0073e9SAndroid Build Coastguard Worker # check double use of training 4038*da0073e9SAndroid Build Coastguard Worker if self.training: 4039*da0073e9SAndroid Build Coastguard 
Worker r = r + 1 4040*da0073e9SAndroid Build Coastguard Worker return r 4041*da0073e9SAndroid Build Coastguard Worker 4042*da0073e9SAndroid Build Coastguard Worker w = What() 4043*da0073e9SAndroid Build Coastguard Worker self.assertEqual(4, w(3)) 4044*da0073e9SAndroid Build Coastguard Worker w.train(False) 4045*da0073e9SAndroid Build Coastguard Worker self.assertEqual(7, w(3)) 4046*da0073e9SAndroid Build Coastguard Worker self.assertFalse("training" in w.state_dict()) 4047*da0073e9SAndroid Build Coastguard Worker 4048*da0073e9SAndroid Build Coastguard Worker def test_class_as_attribute(self): 4049*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4050*da0073e9SAndroid Build Coastguard Worker class Foo321: 4051*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4052*da0073e9SAndroid Build Coastguard Worker self.x = 3 4053*da0073e9SAndroid Build Coastguard Worker 4054*da0073e9SAndroid Build Coastguard Worker class FooBar1234(torch.nn.Module): 4055*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4056*da0073e9SAndroid Build Coastguard Worker super().__init__() 4057*da0073e9SAndroid Build Coastguard Worker self.f = Foo321() 4058*da0073e9SAndroid Build Coastguard Worker 4059*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 4060*da0073e9SAndroid Build Coastguard Worker return x + self.f.x 4061*da0073e9SAndroid Build Coastguard Worker 4062*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(FooBar1234()) 4063*da0073e9SAndroid Build Coastguard Worker eic = self.getExportImportCopy(scripted) 4064*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 4065*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted(x), eic(x)) 4066*da0073e9SAndroid Build Coastguard Worker 4067*da0073e9SAndroid Build Coastguard Worker def test_module_str(self): 4068*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 4069*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 
4070*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 4071*da0073e9SAndroid Build Coastguard Worker 4072*da0073e9SAndroid Build Coastguard Worker f = torch.jit.script(Foo()) 4073*da0073e9SAndroid Build Coastguard Worker 4074*da0073e9SAndroid Build Coastguard Worker str_f = str(f._c) 4075*da0073e9SAndroid Build Coastguard Worker self.assertTrue(str_f.startswith('ScriptObject')) 4076*da0073e9SAndroid Build Coastguard Worker self.assertTrue('__torch__.' in str_f) 4077*da0073e9SAndroid Build Coastguard Worker self.assertTrue('.Foo' in str_f) 4078*da0073e9SAndroid Build Coastguard Worker 4079*da0073e9SAndroid Build Coastguard Worker def test_jitter_bug(self): 4080*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4081*da0073e9SAndroid Build Coastguard Worker def fn2(input, kernel_size): 4082*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, List[int]) -> Tensor 4083*da0073e9SAndroid Build Coastguard Worker if kernel_size[0] > 1: 4084*da0073e9SAndroid Build Coastguard Worker _stride = [2] 4085*da0073e9SAndroid Build Coastguard Worker else: 4086*da0073e9SAndroid Build Coastguard Worker _stride = kernel_size 4087*da0073e9SAndroid Build Coastguard Worker print(_stride, kernel_size) 4088*da0073e9SAndroid Build Coastguard Worker return input 4089*da0073e9SAndroid Build Coastguard Worker 4090*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4091*da0073e9SAndroid Build Coastguard Worker def fn(input): 4092*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 4093*da0073e9SAndroid Build Coastguard Worker return fn2(input, [1]) 4094*da0073e9SAndroid Build Coastguard Worker 4095*da0073e9SAndroid Build Coastguard Worker def test_parser_kwargonly(self): 4096*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 4097*da0073e9SAndroid Build Coastguard Worker def foo(x, *, y) -> Tuple[Tensor, Tensor]: 4098*da0073e9SAndroid Build Coastguard Worker return x, x 4099*da0073e9SAndroid Build Coastguard Worker def 
bar(x): 4100*da0073e9SAndroid Build Coastguard Worker return foo(x, y=x) 4101*da0073e9SAndroid Build Coastguard Worker ''') 4102*da0073e9SAndroid Build Coastguard Worker self.assertTrue('*' in str(cu.foo.schema)) 4103*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "not provided"): 4104*da0073e9SAndroid Build Coastguard Worker torch.jit.CompilationUnit(''' 4105*da0073e9SAndroid Build Coastguard Worker def foo(x, *, y) -> Tuple[Tensor, Tensor]: 4106*da0073e9SAndroid Build Coastguard Worker return x, x 4107*da0073e9SAndroid Build Coastguard Worker def bar(x): 4108*da0073e9SAndroid Build Coastguard Worker return foo(x, x) 4109*da0073e9SAndroid Build Coastguard Worker ''') 4110*da0073e9SAndroid Build Coastguard Worker 4111*da0073e9SAndroid Build Coastguard Worker def test_annoying_doubles(self): 4112*da0073e9SAndroid Build Coastguard Worker mod = types.ModuleType("temp") 4113*da0073e9SAndroid Build Coastguard Worker mod.inf = float("inf") 4114*da0073e9SAndroid Build Coastguard Worker mod.ninf = float("-inf") 4115*da0073e9SAndroid Build Coastguard Worker mod.nan = float("nan") 4116*da0073e9SAndroid Build Coastguard Worker 4117*da0073e9SAndroid Build Coastguard Worker with torch._jit_internal._disable_emit_hooks(): 4118*da0073e9SAndroid Build Coastguard Worker class Foo(torch.jit.ScriptModule): 4119*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4120*da0073e9SAndroid Build Coastguard Worker def forward(self): 4121*da0073e9SAndroid Build Coastguard Worker return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan 4122*da0073e9SAndroid Build Coastguard Worker 4123*da0073e9SAndroid Build Coastguard Worker foo = Foo() 4124*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 4125*da0073e9SAndroid Build Coastguard Worker torch.jit.save(foo, buffer) 4126*da0073e9SAndroid Build Coastguard Worker 4127*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 4128*da0073e9SAndroid Build Coastguard Worker 
foo_loaded = torch.jit.load(buffer) 4129*da0073e9SAndroid Build Coastguard Worker 4130*da0073e9SAndroid Build Coastguard Worker r = foo() 4131*da0073e9SAndroid Build Coastguard Worker r2 = foo_loaded() 4132*da0073e9SAndroid Build Coastguard Worker # use precise assert, we are checking floating point details 4133*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r[:-1] == r2[:-1]) 4134*da0073e9SAndroid Build Coastguard Worker self.assertTrue(math.isnan(r[-1]) and math.isnan(r2[-1])) 4135*da0073e9SAndroid Build Coastguard Worker 4136*da0073e9SAndroid Build Coastguard Worker def test_type_annotate(self): 4137*da0073e9SAndroid Build Coastguard Worker 4138*da0073e9SAndroid Build Coastguard Worker def foo(a): 4139*da0073e9SAndroid Build Coastguard Worker return torch.jit.annotate(torch.Tensor, a) 4140*da0073e9SAndroid Build Coastguard Worker 4141*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(3),)) 4142*da0073e9SAndroid Build Coastguard Worker 4143*da0073e9SAndroid Build Coastguard Worker def bar(): 4144*da0073e9SAndroid Build Coastguard Worker a = torch.jit.annotate(List[int], []) 4145*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 4146*da0073e9SAndroid Build Coastguard Worker a.append(4) 4147*da0073e9SAndroid Build Coastguard Worker return a 4148*da0073e9SAndroid Build Coastguard Worker 4149*da0073e9SAndroid Build Coastguard Worker self.checkScript(bar, ()) 4150*da0073e9SAndroid Build Coastguard Worker 4151*da0073e9SAndroid Build Coastguard Worker def baz(a): 4152*da0073e9SAndroid Build Coastguard Worker return torch.jit.annotate(float, a) 4153*da0073e9SAndroid Build Coastguard Worker self.checkScript(baz, (torch.rand(()),)) 4154*da0073e9SAndroid Build Coastguard Worker 4155*da0073e9SAndroid Build Coastguard Worker # test annotate none types 4156*da0073e9SAndroid Build Coastguard Worker def annotate_none(): 4157*da0073e9SAndroid Build Coastguard Worker return torch.jit.annotate(Optional[torch.Tensor], None) 
4158*da0073e9SAndroid Build Coastguard Worker 4159*da0073e9SAndroid Build Coastguard Worker self.checkScript(annotate_none, ()) 4160*da0073e9SAndroid Build Coastguard Worker 4161*da0073e9SAndroid Build Coastguard Worker 4162*da0073e9SAndroid Build Coastguard Worker def test_robust_op_resolution(self): 4163*da0073e9SAndroid Build Coastguard Worker neg = torch.add # misleading name to make sure we resolve by function 4164*da0073e9SAndroid Build Coastguard Worker 4165*da0073e9SAndroid Build Coastguard Worker def stuff(x): 4166*da0073e9SAndroid Build Coastguard Worker return neg(x, x) 4167*da0073e9SAndroid Build Coastguard Worker 4168*da0073e9SAndroid Build Coastguard Worker a = (torch.rand(3),) 4169*da0073e9SAndroid Build Coastguard Worker self.checkScript(stuff, a) 4170*da0073e9SAndroid Build Coastguard Worker 4171*da0073e9SAndroid Build Coastguard Worker def test_nested_aug_assign(self): 4172*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4173*da0073e9SAndroid Build Coastguard Worker class SomeClass: 4174*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4175*da0073e9SAndroid Build Coastguard Worker self.num = 99 4176*da0073e9SAndroid Build Coastguard Worker 4177*da0073e9SAndroid Build Coastguard Worker def __iadd__(self, x): 4178*da0073e9SAndroid Build Coastguard Worker # type: (int) 4179*da0073e9SAndroid Build Coastguard Worker self.num += x 4180*da0073e9SAndroid Build Coastguard Worker return self 4181*da0073e9SAndroid Build Coastguard Worker 4182*da0073e9SAndroid Build Coastguard Worker def __eq__(self, other): 4183*da0073e9SAndroid Build Coastguard Worker # type: (SomeClass) -> bool 4184*da0073e9SAndroid Build Coastguard Worker return self.num == other.num 4185*da0073e9SAndroid Build Coastguard Worker 4186*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4187*da0073e9SAndroid Build Coastguard Worker class SomeOutOfPlaceClass: 4188*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4189*da0073e9SAndroid 
Build Coastguard Worker self.num = 99 4190*da0073e9SAndroid Build Coastguard Worker 4191*da0073e9SAndroid Build Coastguard Worker def __add__(self, x): 4192*da0073e9SAndroid Build Coastguard Worker # type: (int) 4193*da0073e9SAndroid Build Coastguard Worker self.num = x 4194*da0073e9SAndroid Build Coastguard Worker return self 4195*da0073e9SAndroid Build Coastguard Worker 4196*da0073e9SAndroid Build Coastguard Worker def __eq__(self, other): 4197*da0073e9SAndroid Build Coastguard Worker # type: (SomeClass) -> bool 4198*da0073e9SAndroid Build Coastguard Worker return self.num == other.num 4199*da0073e9SAndroid Build Coastguard Worker 4200*da0073e9SAndroid Build Coastguard Worker class Child(nn.Module): 4201*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4202*da0073e9SAndroid Build Coastguard Worker super().__init__() 4203*da0073e9SAndroid Build Coastguard Worker self.x = 2 4204*da0073e9SAndroid Build Coastguard Worker self.o = SomeClass() 4205*da0073e9SAndroid Build Coastguard Worker self.oop = SomeOutOfPlaceClass() 4206*da0073e9SAndroid Build Coastguard Worker self.list = [1, 2, 3] 4207*da0073e9SAndroid Build Coastguard Worker 4208*da0073e9SAndroid Build Coastguard Worker class A(nn.Module): 4209*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4210*da0073e9SAndroid Build Coastguard Worker super().__init__() 4211*da0073e9SAndroid Build Coastguard Worker self.child = Child() 4212*da0073e9SAndroid Build Coastguard Worker 4213*da0073e9SAndroid Build Coastguard Worker def forward(self): 4214*da0073e9SAndroid Build Coastguard Worker self.child.x += 1 4215*da0073e9SAndroid Build Coastguard Worker self.child.o += 5 4216*da0073e9SAndroid Build Coastguard Worker self.child.oop += 5 4217*da0073e9SAndroid Build Coastguard Worker some_list = [1, 2] 4218*da0073e9SAndroid Build Coastguard Worker self.child.list += some_list 4219*da0073e9SAndroid Build Coastguard Worker self.child.list *= 2 4220*da0073e9SAndroid Build Coastguard Worker 
return self.child.x, self.child.o, self.child.list, self.child.oop 4221*da0073e9SAndroid Build Coastguard Worker 4222*da0073e9SAndroid Build Coastguard Worker a = A() 4223*da0073e9SAndroid Build Coastguard Worker sa = torch.jit.script(A()) 4224*da0073e9SAndroid Build Coastguard Worker eager_result = a() 4225*da0073e9SAndroid Build Coastguard Worker script_result = sa() 4226*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_result, script_result) 4227*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.child.x, sa.child.x) 4228*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.child.o, sa.child.o) 4229*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.child.list, sa.child.list) 4230*da0073e9SAndroid Build Coastguard Worker 4231*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4232*da0073e9SAndroid Build Coastguard Worker class SomeNonAddableClass: 4233*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4234*da0073e9SAndroid Build Coastguard Worker self.num = 99 4235*da0073e9SAndroid Build Coastguard Worker 4236*da0073e9SAndroid Build Coastguard Worker def __eq__(self, other): 4237*da0073e9SAndroid Build Coastguard Worker # type: (SomeClass) -> bool 4238*da0073e9SAndroid Build Coastguard Worker return self.num == other.num 4239*da0073e9SAndroid Build Coastguard Worker 4240*da0073e9SAndroid Build Coastguard Worker # with self.assertRaisesRegex(RuntimeError, "") 4241*da0073e9SAndroid Build Coastguard Worker class A(nn.Module): 4242*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4243*da0073e9SAndroid Build Coastguard Worker super().__init__() 4244*da0073e9SAndroid Build Coastguard Worker self.x = SomeNonAddableClass() 4245*da0073e9SAndroid Build Coastguard Worker 4246*da0073e9SAndroid Build Coastguard Worker def forward(self): 4247*da0073e9SAndroid Build Coastguard Worker self.x += SomeNonAddableClass() 4248*da0073e9SAndroid Build Coastguard Worker return self.x 
4249*da0073e9SAndroid Build Coastguard Worker 4250*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"): 4251*da0073e9SAndroid Build Coastguard Worker torch.jit.script(A()) 4252*da0073e9SAndroid Build Coastguard Worker 4253*da0073e9SAndroid Build Coastguard Worker def test_var_aug_assign(self): 4254*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4255*da0073e9SAndroid Build Coastguard Worker class SomeNonAddableClass: 4256*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4257*da0073e9SAndroid Build Coastguard Worker self.num = 99 4258*da0073e9SAndroid Build Coastguard Worker 4259*da0073e9SAndroid Build Coastguard Worker def __eq__(self, other): 4260*da0073e9SAndroid Build Coastguard Worker # type: (SomeNonAddableClass) -> bool 4261*da0073e9SAndroid Build Coastguard Worker return self.num == other.num 4262*da0073e9SAndroid Build Coastguard Worker 4263*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot emit inplace op"): 4264*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4265*da0073e9SAndroid Build Coastguard Worker def fn(): 4266*da0073e9SAndroid Build Coastguard Worker a = SomeNonAddableClass() 4267*da0073e9SAndroid Build Coastguard Worker a += SomeNonAddableClass() 4268*da0073e9SAndroid Build Coastguard Worker return a 4269*da0073e9SAndroid Build Coastguard Worker 4270*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4271*da0073e9SAndroid Build Coastguard Worker class SomeClass: 4272*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4273*da0073e9SAndroid Build Coastguard Worker self.num = 99 4274*da0073e9SAndroid Build Coastguard Worker 4275*da0073e9SAndroid Build Coastguard Worker def __iadd__(self, x): 4276*da0073e9SAndroid Build Coastguard Worker # type: (int) 4277*da0073e9SAndroid Build Coastguard Worker self.num += x 4278*da0073e9SAndroid Build Coastguard Worker return self 
4279*da0073e9SAndroid Build Coastguard Worker 4280*da0073e9SAndroid Build Coastguard Worker def __eq__(self, other): 4281*da0073e9SAndroid Build Coastguard Worker # type: (SomeClass) -> bool 4282*da0073e9SAndroid Build Coastguard Worker return self.num == other.num 4283*da0073e9SAndroid Build Coastguard Worker 4284*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4285*da0073e9SAndroid Build Coastguard Worker class SomeOutOfPlaceClass: 4286*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4287*da0073e9SAndroid Build Coastguard Worker self.num = 99 4288*da0073e9SAndroid Build Coastguard Worker 4289*da0073e9SAndroid Build Coastguard Worker def __add__(self, x): 4290*da0073e9SAndroid Build Coastguard Worker # type: (int) 4291*da0073e9SAndroid Build Coastguard Worker self.num = x 4292*da0073e9SAndroid Build Coastguard Worker return self 4293*da0073e9SAndroid Build Coastguard Worker 4294*da0073e9SAndroid Build Coastguard Worker def __eq__(self, other): 4295*da0073e9SAndroid Build Coastguard Worker # type: (SomeClass) -> bool 4296*da0073e9SAndroid Build Coastguard Worker return self.num == other.num 4297*da0073e9SAndroid Build Coastguard Worker 4298*da0073e9SAndroid Build Coastguard Worker def fn2(): 4299*da0073e9SAndroid Build Coastguard Worker a = SomeClass() 4300*da0073e9SAndroid Build Coastguard Worker a_copy = a 4301*da0073e9SAndroid Build Coastguard Worker a += 20 4302*da0073e9SAndroid Build Coastguard Worker assert a is a_copy 4303*da0073e9SAndroid Build Coastguard Worker b = SomeOutOfPlaceClass() 4304*da0073e9SAndroid Build Coastguard Worker b_copy = b 4305*da0073e9SAndroid Build Coastguard Worker b += 99 4306*da0073e9SAndroid Build Coastguard Worker assert b is b_copy 4307*da0073e9SAndroid Build Coastguard Worker c = [1, 2, 3] 4308*da0073e9SAndroid Build Coastguard Worker c_copy = c 4309*da0073e9SAndroid Build Coastguard Worker c *= 2 4310*da0073e9SAndroid Build Coastguard Worker assert c is c_copy 4311*da0073e9SAndroid Build 
Coastguard Worker c += [4, 5, 6] 4312*da0073e9SAndroid Build Coastguard Worker d = torch.ones(2, 2) 4313*da0073e9SAndroid Build Coastguard Worker d_copy = d 4314*da0073e9SAndroid Build Coastguard Worker d += torch.ones(2, 2) 4315*da0073e9SAndroid Build Coastguard Worker assert d is d_copy 4316*da0073e9SAndroid Build Coastguard Worker return a, b, c, d 4317*da0073e9SAndroid Build Coastguard Worker 4318*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn2, []) 4319*da0073e9SAndroid Build Coastguard Worker 4320*da0073e9SAndroid Build Coastguard Worker def test_nested_list_construct(self): 4321*da0073e9SAndroid Build Coastguard Worker def foo(): 4322*da0073e9SAndroid Build Coastguard Worker return [[4]] + [[4, 5]] 4323*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, ()) 4324*da0073e9SAndroid Build Coastguard Worker 4325*da0073e9SAndroid Build Coastguard Worker def test_file_line_error(self): 4326*da0073e9SAndroid Build Coastguard Worker def foobar(xyz): 4327*da0073e9SAndroid Build Coastguard Worker return torch.blargh(xyz) 4328*da0073e9SAndroid Build Coastguard Worker 4329*da0073e9SAndroid Build Coastguard Worker _, lineno = inspect.getsourcelines(foobar) 4330*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 1}'): 4331*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(foobar) 4332*da0073e9SAndroid Build Coastguard Worker 4333*da0073e9SAndroid Build Coastguard Worker def test_file_line_error_class_defn(self): 4334*da0073e9SAndroid Build Coastguard Worker class FooBar: 4335*da0073e9SAndroid Build Coastguard Worker def baz(self, xyz): 4336*da0073e9SAndroid Build Coastguard Worker return torch.blargh(xyz) 4337*da0073e9SAndroid Build Coastguard Worker 4338*da0073e9SAndroid Build Coastguard Worker _, lineno = inspect.getsourcelines(FooBar) 4339*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 2}'): 
4340*da0073e9SAndroid Build Coastguard Worker torch.jit.script(FooBar) 4341*da0073e9SAndroid Build Coastguard Worker 4342*da0073e9SAndroid Build Coastguard Worker def test_file_line_graph(self): 4343*da0073e9SAndroid Build Coastguard Worker def foobar(xyz): 4344*da0073e9SAndroid Build Coastguard Worker return torch.neg(xyz) 4345*da0073e9SAndroid Build Coastguard Worker 4346*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(foobar) 4347*da0073e9SAndroid Build Coastguard Worker 4348*da0073e9SAndroid Build Coastguard Worker _, lineno = inspect.getsourcelines(foobar) 4349*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check(f'test_jit.py:{lineno + 1}:19') 4350*da0073e9SAndroid Build Coastguard Worker fc.run(scripted.graph) 4351*da0073e9SAndroid Build Coastguard Worker fc.run(str(scripted.graph)) 4352*da0073e9SAndroid Build Coastguard Worker 4353*da0073e9SAndroid Build Coastguard Worker def test_file_line_save_load(self): 4354*da0073e9SAndroid Build Coastguard Worker class Scripted(torch.jit.ScriptModule): 4355*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4356*da0073e9SAndroid Build Coastguard Worker def forward(self, xyz): 4357*da0073e9SAndroid Build Coastguard Worker return torch.neg(xyz) 4358*da0073e9SAndroid Build Coastguard Worker 4359*da0073e9SAndroid Build Coastguard Worker scripted = Scripted() 4360*da0073e9SAndroid Build Coastguard Worker 4361*da0073e9SAndroid Build Coastguard Worker # NB: not using getExportImportCopy because that takes a different 4362*da0073e9SAndroid Build Coastguard Worker # code path that calls CompilationUnit._import rather than 4363*da0073e9SAndroid Build Coastguard Worker # going through the full save/load pathway 4364*da0073e9SAndroid Build Coastguard Worker buffer = scripted.save_to_buffer() 4365*da0073e9SAndroid Build Coastguard Worker bytesio = io.BytesIO(buffer) 4366*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.load(bytesio) 4367*da0073e9SAndroid Build Coastguard 
Worker 4368*da0073e9SAndroid Build Coastguard Worker _, lineno = inspect.getsourcelines(Scripted) 4369*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check(f':{lineno + 3}') 4370*da0073e9SAndroid Build Coastguard Worker fc.run(scripted.graph) 4371*da0073e9SAndroid Build Coastguard Worker fc.run(str(scripted.graph)) 4372*da0073e9SAndroid Build Coastguard Worker 4373*da0073e9SAndroid Build Coastguard Worker def test_file_line_string(self): 4374*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.CompilationUnit(''' 4375*da0073e9SAndroid Build Coastguard Workerdef foo(xyz): 4376*da0073e9SAndroid Build Coastguard Worker return torch.neg(xyz) 4377*da0073e9SAndroid Build Coastguard Worker ''') 4378*da0073e9SAndroid Build Coastguard Worker 4379*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check('<string>:3:11') 4380*da0073e9SAndroid Build Coastguard Worker fc.run(scripted.foo.graph) 4381*da0073e9SAndroid Build Coastguard Worker fc.run(str(scripted.foo.graph)) 4382*da0073e9SAndroid Build Coastguard Worker 4383*da0073e9SAndroid Build Coastguard Worker @skipIfCrossRef 4384*da0073e9SAndroid Build Coastguard Worker def test_file_line_trace(self): 4385*da0073e9SAndroid Build Coastguard Worker def foobar(xyz): 4386*da0073e9SAndroid Build Coastguard Worker return torch.neg(xyz) 4387*da0073e9SAndroid Build Coastguard Worker 4388*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.trace(foobar, (torch.rand(3, 4))) 4389*da0073e9SAndroid Build Coastguard Worker 4390*da0073e9SAndroid Build Coastguard Worker _, lineno = inspect.getsourcelines(foobar) 4391*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check(f'test_jit.py:{lineno + 1}:0') 4392*da0073e9SAndroid Build Coastguard Worker fc.run(scripted.graph) 4393*da0073e9SAndroid Build Coastguard Worker fc.run(str(scripted.graph)) 4394*da0073e9SAndroid Build Coastguard Worker 4395*da0073e9SAndroid Build Coastguard Worker def test_serialized_source_ranges(self): 4396*da0073e9SAndroid 
Build Coastguard Worker 4397*da0073e9SAndroid Build Coastguard Worker class FooTest(torch.jit.ScriptModule): 4398*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4399*da0073e9SAndroid Build Coastguard Worker def forward(self, x, w): 4400*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, w.t()) 4401*da0073e9SAndroid Build Coastguard Worker 4402*da0073e9SAndroid Build Coastguard Worker ft = FooTest() 4403*da0073e9SAndroid Build Coastguard Worker loaded = self.getExportImportCopy(ft) 4404*da0073e9SAndroid Build Coastguard Worker _, lineno = inspect.getsourcelines(FooTest) 4405*da0073e9SAndroid Build Coastguard Worker 4406*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, f'test_jit.py", line {lineno + 3}'): 4407*da0073e9SAndroid Build Coastguard Worker loaded(torch.rand(3, 4), torch.rand(30, 40)) 4408*da0073e9SAndroid Build Coastguard Worker 4409*da0073e9SAndroid Build Coastguard Worker def test_serialized_source_ranges_graph(self): 4410*da0073e9SAndroid Build Coastguard Worker 4411*da0073e9SAndroid Build Coastguard Worker class FooTest3(torch.jit.ScriptModule): 4412*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4413*da0073e9SAndroid Build Coastguard Worker def forward(self, x, w): 4414*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, w.t()) 4415*da0073e9SAndroid Build Coastguard Worker 4416*da0073e9SAndroid Build Coastguard Worker ft = FooTest3() 4417*da0073e9SAndroid Build Coastguard Worker loaded = self.getExportImportCopy(ft) 4418*da0073e9SAndroid Build Coastguard Worker _, lineno = inspect.getsourcelines(FooTest3) 4419*da0073e9SAndroid Build Coastguard Worker 4420*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check(f'test_jit.py:{lineno + 3}') 4421*da0073e9SAndroid Build Coastguard Worker fc.run(loaded.graph) 4422*da0073e9SAndroid Build Coastguard Worker 4423*da0073e9SAndroid Build Coastguard Worker def test_serialized_source_ranges2(self): 
4424*da0073e9SAndroid Build Coastguard Worker 4425*da0073e9SAndroid Build Coastguard Worker class FooTest2(torch.jit.ScriptModule): 4426*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4427*da0073e9SAndroid Build Coastguard Worker def forward(self): 4428*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('foo') 4429*da0073e9SAndroid Build Coastguard Worker 4430*da0073e9SAndroid Build Coastguard Worker _, lineno = inspect.getsourcelines(FooTest2) 4431*da0073e9SAndroid Build Coastguard Worker 4432*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.Error, f'test_jit.py", line {lineno + 3}'): 4433*da0073e9SAndroid Build Coastguard Worker ft = FooTest2() 4434*da0073e9SAndroid Build Coastguard Worker loaded = self.getExportImportCopy(ft) 4435*da0073e9SAndroid Build Coastguard Worker loaded() 4436*da0073e9SAndroid Build Coastguard Worker 4437*da0073e9SAndroid Build Coastguard Worker def test_serialized_source_ranges_dont_jitter(self): 4438*da0073e9SAndroid Build Coastguard Worker class FooTest3(torch.jit.ScriptModule): 4439*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4440*da0073e9SAndroid Build Coastguard Worker def forward(self, lim): 4441*da0073e9SAndroid Build Coastguard Worker first = 1 4442*da0073e9SAndroid Build Coastguard Worker second = 1 4443*da0073e9SAndroid Build Coastguard Worker i = 1 4444*da0073e9SAndroid Build Coastguard Worker somenum = 5 4445*da0073e9SAndroid Build Coastguard Worker dontmutateme = 3 4446*da0073e9SAndroid Build Coastguard Worker third = 0 4447*da0073e9SAndroid Build Coastguard Worker while bool(i < lim): 4448*da0073e9SAndroid Build Coastguard Worker third = first + second 4449*da0073e9SAndroid Build Coastguard Worker first = second 4450*da0073e9SAndroid Build Coastguard Worker second = third 4451*da0073e9SAndroid Build Coastguard Worker j = 0 4452*da0073e9SAndroid Build Coastguard Worker while j < 10: 4453*da0073e9SAndroid Build Coastguard Worker somenum = somenum 
* 2 4454*da0073e9SAndroid Build Coastguard Worker j = j + 1 4455*da0073e9SAndroid Build Coastguard Worker i = i + j 4456*da0073e9SAndroid Build Coastguard Worker i = i + dontmutateme 4457*da0073e9SAndroid Build Coastguard Worker 4458*da0073e9SAndroid Build Coastguard Worker st = second + third 4459*da0073e9SAndroid Build Coastguard Worker fs = first + second 4460*da0073e9SAndroid Build Coastguard Worker return third, st, fs 4461*da0073e9SAndroid Build Coastguard Worker 4462*da0073e9SAndroid Build Coastguard Worker ft3 = FooTest3() 4463*da0073e9SAndroid Build Coastguard Worker 4464*da0073e9SAndroid Build Coastguard Worker def debug_records_from_mod(self, mod): 4465*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 4466*da0073e9SAndroid Build Coastguard Worker torch.jit.save(ft3, buffer) 4467*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 4468*da0073e9SAndroid Build Coastguard Worker archive = zipfile.ZipFile(buffer) 4469*da0073e9SAndroid Build Coastguard Worker files = filter(lambda x: x.startswith('archive/code/'), archive.namelist()) 4470*da0073e9SAndroid Build Coastguard Worker debug_files = list(filter(lambda f: f.endswith('.debug_pkl'), files)) 4471*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(debug_files), 1) 4472*da0073e9SAndroid Build Coastguard Worker debug_file = archive.open(debug_files[0]) 4473*da0073e9SAndroid Build Coastguard Worker return pickle.load(debug_file), buffer 4474*da0073e9SAndroid Build Coastguard Worker 4475*da0073e9SAndroid Build Coastguard Worker records1, buffer = debug_records_from_mod(self, ft3) 4476*da0073e9SAndroid Build Coastguard Worker 4477*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 4478*da0073e9SAndroid Build Coastguard Worker loaded = torch.jit.load(buffer) 4479*da0073e9SAndroid Build Coastguard Worker records2, buffer = debug_records_from_mod(self, loaded) 4480*da0073e9SAndroid Build Coastguard Worker 4481*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 
4482*da0073e9SAndroid Build Coastguard Worker loaded2 = torch.jit.load(buffer) 4483*da0073e9SAndroid Build Coastguard Worker records3, _ = debug_records_from_mod(self, loaded2) 4484*da0073e9SAndroid Build Coastguard Worker 4485*da0073e9SAndroid Build Coastguard Worker self.assertEqual(records1, records2) 4486*da0073e9SAndroid Build Coastguard Worker self.assertEqual(records2, records3) 4487*da0073e9SAndroid Build Coastguard Worker 4488*da0073e9SAndroid Build Coastguard Worker def test_serialized_source_ranges_no_dups(self): 4489*da0073e9SAndroid Build Coastguard Worker class FooTest3(torch.jit.ScriptModule): 4490*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4491*da0073e9SAndroid Build Coastguard Worker def forward(self, lim): 4492*da0073e9SAndroid Build Coastguard Worker first = 1 4493*da0073e9SAndroid Build Coastguard Worker second = 1 4494*da0073e9SAndroid Build Coastguard Worker i = 1 4495*da0073e9SAndroid Build Coastguard Worker somenum = 5 4496*da0073e9SAndroid Build Coastguard Worker dontmutateme = 3 4497*da0073e9SAndroid Build Coastguard Worker third = 0 4498*da0073e9SAndroid Build Coastguard Worker while bool(i < lim): 4499*da0073e9SAndroid Build Coastguard Worker third = first + second 4500*da0073e9SAndroid Build Coastguard Worker first = second 4501*da0073e9SAndroid Build Coastguard Worker second = third 4502*da0073e9SAndroid Build Coastguard Worker j = 0 4503*da0073e9SAndroid Build Coastguard Worker while j < 10: 4504*da0073e9SAndroid Build Coastguard Worker somenum = somenum * 2 4505*da0073e9SAndroid Build Coastguard Worker j = j + 1 4506*da0073e9SAndroid Build Coastguard Worker i = i + j 4507*da0073e9SAndroid Build Coastguard Worker i = i + dontmutateme 4508*da0073e9SAndroid Build Coastguard Worker 4509*da0073e9SAndroid Build Coastguard Worker st = second + third 4510*da0073e9SAndroid Build Coastguard Worker fs = first + second 4511*da0073e9SAndroid Build Coastguard Worker return third, st, fs 4512*da0073e9SAndroid Build Coastguard 
Worker 4513*da0073e9SAndroid Build Coastguard Worker ft3 = FooTest3() 4514*da0073e9SAndroid Build Coastguard Worker 4515*da0073e9SAndroid Build Coastguard Worker def debug_records_from_mod(mod): 4516*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 4517*da0073e9SAndroid Build Coastguard Worker torch.jit.save(ft3, buffer) 4518*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 4519*da0073e9SAndroid Build Coastguard Worker archive = zipfile.ZipFile(buffer) 4520*da0073e9SAndroid Build Coastguard Worker files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) 4521*da0073e9SAndroid Build Coastguard Worker debug_files = filter(lambda f: f.endswith('.debug_pkl'), files) 4522*da0073e9SAndroid Build Coastguard Worker debug_files = (archive.open(f) for f in debug_files) 4523*da0073e9SAndroid Build Coastguard Worker debug_files = (pickle.load(f) for f in debug_files) 4524*da0073e9SAndroid Build Coastguard Worker debug_files = (f[2] for f in debug_files) 4525*da0073e9SAndroid Build Coastguard Worker return list(debug_files) 4526*da0073e9SAndroid Build Coastguard Worker 4527*da0073e9SAndroid Build Coastguard Worker debug_files = debug_records_from_mod(ft3) 4528*da0073e9SAndroid Build Coastguard Worker for debug_file in debug_files: 4529*da0073e9SAndroid Build Coastguard Worker for i in range(len(debug_file) - 1): 4530*da0073e9SAndroid Build Coastguard Worker offset, source_range_tag, source_range = debug_file[i] 4531*da0073e9SAndroid Build Coastguard Worker offset2, source_range_tag2, source_range2 = debug_file[i + 1] 4532*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(source_range, source_range2) 4533*da0073e9SAndroid Build Coastguard Worker 4534*da0073e9SAndroid Build Coastguard Worker def test_circular_dependency(self): 4535*da0073e9SAndroid Build Coastguard Worker """ 4536*da0073e9SAndroid Build Coastguard Worker https://github.com/pytorch/pytorch/issues/25871 4537*da0073e9SAndroid Build Coastguard Worker """ 
4538*da0073e9SAndroid Build Coastguard Worker class A(torch.jit.ScriptModule): 4539*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4540*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 4541*da0073e9SAndroid Build Coastguard Worker return x 4542*da0073e9SAndroid Build Coastguard Worker 4543*da0073e9SAndroid Build Coastguard Worker class B(torch.jit.ScriptModule): 4544*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4545*da0073e9SAndroid Build Coastguard Worker super().__init__() 4546*da0073e9SAndroid Build Coastguard Worker self.foo = torch.nn.ModuleList([A()]) 4547*da0073e9SAndroid Build Coastguard Worker 4548*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4549*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 4550*da0073e9SAndroid Build Coastguard Worker for f in self.foo: 4551*da0073e9SAndroid Build Coastguard Worker x = f(x) 4552*da0073e9SAndroid Build Coastguard Worker return x 4553*da0073e9SAndroid Build Coastguard Worker 4554*da0073e9SAndroid Build Coastguard Worker class C(torch.jit.ScriptModule): 4555*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 4556*da0073e9SAndroid Build Coastguard Worker super().__init__() 4557*da0073e9SAndroid Build Coastguard Worker self.foo = torch.nn.Sequential(B()) 4558*da0073e9SAndroid Build Coastguard Worker 4559*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 4560*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 4561*da0073e9SAndroid Build Coastguard Worker for f in self.foo: 4562*da0073e9SAndroid Build Coastguard Worker x = f(x) 4563*da0073e9SAndroid Build Coastguard Worker return x 4564*da0073e9SAndroid Build Coastguard Worker self.getExportImportCopy(C()) 4565*da0073e9SAndroid Build Coastguard Worker 4566*da0073e9SAndroid Build Coastguard Worker def test_serialize_long_lines(self): 4567*da0073e9SAndroid Build Coastguard Worker class OrderModuleLong(torch.nn.Module): 4568*da0073e9SAndroid 
Build Coastguard Worker def forward(self, long_arg_name: List[torch.Tensor]): 4569*da0073e9SAndroid Build Coastguard Worker return [(long_arg_name[1],), (long_arg_name[0].argmax(),)] 4570*da0073e9SAndroid Build Coastguard Worker src = str(torch.jit.script(OrderModuleLong()).code) 4571*da0073e9SAndroid Build Coastguard Worker # make long_arg_name[1] does not get reordered after the argmax 4572*da0073e9SAndroid Build Coastguard Worker FileCheck().check("long_arg_name[1]").check("argmax").run(src) 4573*da0073e9SAndroid Build Coastguard Worker 4574*da0073e9SAndroid Build Coastguard Worker def test_tensor_shape(self): 4575*da0073e9SAndroid Build Coastguard Worker x = torch.empty(34, 56, 78) 4576*da0073e9SAndroid Build Coastguard Worker 4577*da0073e9SAndroid Build Coastguard Worker def f(x): 4578*da0073e9SAndroid Build Coastguard Worker return x.shape 4579*da0073e9SAndroid Build Coastguard Worker 4580*da0073e9SAndroid Build Coastguard Worker self.checkScript(f, (x,)) 4581*da0073e9SAndroid Build Coastguard Worker 4582*da0073e9SAndroid Build Coastguard Worker 4583*da0073e9SAndroid Build Coastguard Worker def test_block_input_grad_in_loop(self): 4584*da0073e9SAndroid Build Coastguard Worker 4585*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3, requires_grad=False) 4586*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, 3, requires_grad=True) 4587*da0073e9SAndroid Build Coastguard Worker 4588*da0073e9SAndroid Build Coastguard Worker def grad_in_loop(x, y): 4589*da0073e9SAndroid Build Coastguard Worker for i in range(100): 4590*da0073e9SAndroid Build Coastguard Worker x = y @ x 4591*da0073e9SAndroid Build Coastguard Worker return x 4592*da0073e9SAndroid Build Coastguard Worker 4593*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(grad_in_loop) 4594*da0073e9SAndroid Build Coastguard Worker outer = scripted.graph_for(x, y) 4595*da0073e9SAndroid Build Coastguard Worker loop = outer.findNode("prim::Loop") 4596*da0073e9SAndroid Build 
Coastguard Worker loop_block = next(loop.blocks()) 4597*da0073e9SAndroid Build Coastguard Worker param_node = loop_block.paramNode() 4598*da0073e9SAndroid Build Coastguard Worker x_value = list(param_node.outputs())[1] 4599*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x_value.requires_grad()) 4600*da0073e9SAndroid Build Coastguard Worker 4601*da0073e9SAndroid Build Coastguard Worker def test_tensor_grad(self): 4602*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 4, requires_grad=True) 4603*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, 4, requires_grad=False) 4604*da0073e9SAndroid Build Coastguard Worker 4605*da0073e9SAndroid Build Coastguard Worker def f_requires_grad(x): 4606*da0073e9SAndroid Build Coastguard Worker return x.requires_grad 4607*da0073e9SAndroid Build Coastguard Worker 4608*da0073e9SAndroid Build Coastguard Worker self.checkScript(f_requires_grad, (x,)) 4609*da0073e9SAndroid Build Coastguard Worker self.checkScript(f_requires_grad, (y,)) 4610*da0073e9SAndroid Build Coastguard Worker 4611*da0073e9SAndroid Build Coastguard Worker def f_grad(x): 4612*da0073e9SAndroid Build Coastguard Worker return x.grad 4613*da0073e9SAndroid Build Coastguard Worker 4614*da0073e9SAndroid Build Coastguard Worker x.sum().backward() 4615*da0073e9SAndroid Build Coastguard Worker self.checkScript(f_grad, (x,)) 4616*da0073e9SAndroid Build Coastguard Worker self.checkScript(f_grad, (y,)) 4617*da0073e9SAndroid Build Coastguard Worker 4618*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "shape analysis is only enabled in Legacy") 4619*da0073e9SAndroid Build Coastguard Worker def test_prim_grad_undefined(self): 4620*da0073e9SAndroid Build Coastguard Worker 4621*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2) 4622*da0073e9SAndroid Build Coastguard Worker 4623*da0073e9SAndroid Build Coastguard Worker def f_grad(x): 4624*da0073e9SAndroid Build Coastguard Worker return x.grad 
4625*da0073e9SAndroid Build Coastguard Worker 4626*da0073e9SAndroid Build Coastguard Worker scripted = self.checkScript(f_grad, (x,)) 4627*da0073e9SAndroid Build Coastguard Worker g = scripted.graph_for(x) 4628*da0073e9SAndroid Build Coastguard Worker 4629*da0073e9SAndroid Build Coastguard Worker prim_grad_node = g.findNode("prim::grad") 4630*da0073e9SAndroid Build Coastguard Worker self.assertTrue(next(prim_grad_node.outputs()).type().undefined() is None) 4631*da0073e9SAndroid Build Coastguard Worker 4632*da0073e9SAndroid Build Coastguard Worker def test_tensor_data(self): 4633*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 4, requires_grad=True) 4634*da0073e9SAndroid Build Coastguard Worker y = torch.randn(4, 5) 4635*da0073e9SAndroid Build Coastguard Worker 4636*da0073e9SAndroid Build Coastguard Worker def f_data(x): 4637*da0073e9SAndroid Build Coastguard Worker return x.data 4638*da0073e9SAndroid Build Coastguard Worker 4639*da0073e9SAndroid Build Coastguard Worker scripted_f_data = torch.jit.script(f_data) 4640*da0073e9SAndroid Build Coastguard Worker 4641*da0073e9SAndroid Build Coastguard Worker scripted_x = scripted_f_data(x) 4642*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_x, f_data(x)) 4643*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_x.requires_grad, False) 4644*da0073e9SAndroid Build Coastguard Worker 4645*da0073e9SAndroid Build Coastguard Worker scripted_y = scripted_f_data(y) 4646*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_y, f_data(y)) 4647*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_x.requires_grad, False) 4648*da0073e9SAndroid Build Coastguard Worker 4649*da0073e9SAndroid Build Coastguard Worker def test_tensor_dtype(self): 4650*da0073e9SAndroid Build Coastguard Worker x_byte = torch.empty(34, 56, 78, dtype=torch.uint8) 4651*da0073e9SAndroid Build Coastguard Worker x_long = torch.empty(34, 56, 78, dtype=torch.long) 4652*da0073e9SAndroid Build 
Coastguard Worker x_float32 = torch.empty(34, 56, 78, dtype=torch.float32) 4653*da0073e9SAndroid Build Coastguard Worker 4654*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4655*da0073e9SAndroid Build Coastguard Worker def byte(x): 4656*da0073e9SAndroid Build Coastguard Worker return x.dtype == torch.uint8 4657*da0073e9SAndroid Build Coastguard Worker 4658*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4659*da0073e9SAndroid Build Coastguard Worker def long(x): 4660*da0073e9SAndroid Build Coastguard Worker return x.dtype == torch.long 4661*da0073e9SAndroid Build Coastguard Worker 4662*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4663*da0073e9SAndroid Build Coastguard Worker def float32(x): 4664*da0073e9SAndroid Build Coastguard Worker return x.dtype == torch.float32 4665*da0073e9SAndroid Build Coastguard Worker 4666*da0073e9SAndroid Build Coastguard Worker self.assertTrue(byte(x_byte)) 4667*da0073e9SAndroid Build Coastguard Worker self.assertFalse(byte(x_long)) 4668*da0073e9SAndroid Build Coastguard Worker self.assertFalse(byte(x_float32)) 4669*da0073e9SAndroid Build Coastguard Worker self.assertFalse(long(x_byte)) 4670*da0073e9SAndroid Build Coastguard Worker self.assertTrue(long(x_long)) 4671*da0073e9SAndroid Build Coastguard Worker self.assertFalse(long(x_float32)) 4672*da0073e9SAndroid Build Coastguard Worker self.assertFalse(float32(x_byte)) 4673*da0073e9SAndroid Build Coastguard Worker self.assertFalse(float32(x_long)) 4674*da0073e9SAndroid Build Coastguard Worker self.assertTrue(float32(x_float32)) 4675*da0073e9SAndroid Build Coastguard Worker 4676*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") 4677*da0073e9SAndroid Build Coastguard Worker def test_tensor_device(self): 4678*da0073e9SAndroid Build Coastguard Worker cpu = torch.empty(34, 56, 78, device='cpu') 4679*da0073e9SAndroid Build Coastguard Worker gpu = torch.empty(34, 56, 78, device='cuda') 4680*da0073e9SAndroid 
Build Coastguard Worker 4681*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4682*da0073e9SAndroid Build Coastguard Worker def same_device(x, y): 4683*da0073e9SAndroid Build Coastguard Worker return x.device == y.device 4684*da0073e9SAndroid Build Coastguard Worker 4685*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same_device(cpu, cpu)) 4686*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same_device(gpu, gpu)) 4687*da0073e9SAndroid Build Coastguard Worker self.assertFalse(same_device(cpu, gpu)) 4688*da0073e9SAndroid Build Coastguard Worker 4689*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") 4690*da0073e9SAndroid Build Coastguard Worker def test_tensor_to_device(self): 4691*da0073e9SAndroid Build Coastguard Worker def to_device(x): 4692*da0073e9SAndroid Build Coastguard Worker return x.to(device="cuda").to(device=torch.device("cpu")) 4693*da0073e9SAndroid Build Coastguard Worker 4694*da0073e9SAndroid Build Coastguard Worker self.checkScript(to_device, (torch.ones(3, 4),)) 4695*da0073e9SAndroid Build Coastguard Worker 4696*da0073e9SAndroid Build Coastguard Worker def test_tensor_to_cpu(self): 4697*da0073e9SAndroid Build Coastguard Worker def to_cpu(x): 4698*da0073e9SAndroid Build Coastguard Worker return x.cpu() 4699*da0073e9SAndroid Build Coastguard Worker 4700*da0073e9SAndroid Build Coastguard Worker x = torch.ones(3, 4) 4701*da0073e9SAndroid Build Coastguard Worker script_fn = torch.jit.script(to_cpu) 4702*da0073e9SAndroid Build Coastguard Worker self.assertEqual(to_cpu(x).device, script_fn(x).device) 4703*da0073e9SAndroid Build Coastguard Worker self.checkScript(to_cpu, (x,)) 4704*da0073e9SAndroid Build Coastguard Worker 4705*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") 4706*da0073e9SAndroid Build Coastguard Worker def test_tensor_to_cuda(self): 4707*da0073e9SAndroid Build Coastguard Worker def to_cuda(x): 
4708*da0073e9SAndroid Build Coastguard Worker return x.cuda() 4709*da0073e9SAndroid Build Coastguard Worker 4710*da0073e9SAndroid Build Coastguard Worker x = torch.ones(3, 4) 4711*da0073e9SAndroid Build Coastguard Worker script_fn = torch.jit.script(to_cuda) 4712*da0073e9SAndroid Build Coastguard Worker self.assertEqual(to_cuda(x).device, script_fn(x).device) 4713*da0073e9SAndroid Build Coastguard Worker self.checkScript(to_cuda, (x,)) 4714*da0073e9SAndroid Build Coastguard Worker 4715*da0073e9SAndroid Build Coastguard Worker def test_generic_list_errors(self): 4716*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "previously matched to type"): 4717*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4718*da0073e9SAndroid Build Coastguard Worker def foo(x): 4719*da0073e9SAndroid Build Coastguard Worker return [[x]] + [[1]] 4720*da0073e9SAndroid Build Coastguard Worker 4721*da0073e9SAndroid Build Coastguard Worker def test_script_cu(self): 4722*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 4723*da0073e9SAndroid Build Coastguard Worker def foo(a): 4724*da0073e9SAndroid Build Coastguard Worker b = a 4725*da0073e9SAndroid Build Coastguard Worker return b 4726*da0073e9SAndroid Build Coastguard Worker ''') 4727*da0073e9SAndroid Build Coastguard Worker a = Variable(torch.rand(1)) 4728*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, cu.foo(a)) 4729*da0073e9SAndroid Build Coastguard Worker 4730*da0073e9SAndroid Build Coastguard Worker # because the compilation unit ingests python strings 4731*da0073e9SAndroid Build Coastguard Worker # to use an escape sequence escape the backslash (\\n = \n) 4732*da0073e9SAndroid Build Coastguard Worker def test_string_cu(self): 4733*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 4734*da0073e9SAndroid Build Coastguard Worker def foo(a): 4735*da0073e9SAndroid Build Coastguard Worker print(a, """a\\n\tb\\n""", 2, "a\ 
4736*da0073e9SAndroid Build Coastguard Workera") 4737*da0073e9SAndroid Build Coastguard Worker return a 4738*da0073e9SAndroid Build Coastguard Worker ''') 4739*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph)) 4740*da0073e9SAndroid Build Coastguard Worker 4741*da0073e9SAndroid Build Coastguard Worker def test_function_compilation_caching(self): 4742*da0073e9SAndroid Build Coastguard Worker def fun(): 4743*da0073e9SAndroid Build Coastguard Worker return 1 + 2 4744*da0073e9SAndroid Build Coastguard Worker 4745*da0073e9SAndroid Build Coastguard Worker fun_compiled = torch.jit.script(fun) 4746*da0073e9SAndroid Build Coastguard Worker # python wrapper around the script function is a different pointer, 4747*da0073e9SAndroid Build Coastguard Worker # but the underlying script function graph is the same 4748*da0073e9SAndroid Build Coastguard Worker self.assertIs(fun_compiled.graph, torch.jit.script(fun).graph) 4749*da0073e9SAndroid Build Coastguard Worker 4750*da0073e9SAndroid Build Coastguard Worker def fun(): 4751*da0073e9SAndroid Build Coastguard Worker return 3 + 4 4752*da0073e9SAndroid Build Coastguard Worker 4753*da0073e9SAndroid Build Coastguard Worker num_ref_counts = sys.getrefcount(fun) 4754*da0073e9SAndroid Build Coastguard Worker 4755*da0073e9SAndroid Build Coastguard Worker # caching doesn't get tripped up by same qualname 4756*da0073e9SAndroid Build Coastguard Worker fun_compiled_2 = torch.jit.script(fun) 4757*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(fun_compiled, fun_compiled_2) 4758*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fun_compiled_2(), 7) 4759*da0073e9SAndroid Build Coastguard Worker 4760*da0073e9SAndroid Build Coastguard Worker # caching doesnt increase refcounts to function (holds weak reference) 4761*da0073e9SAndroid Build Coastguard Worker self.assertTrue(sys.getrefcount(fun), num_ref_counts) 4762*da0073e9SAndroid Build Coastguard Worker 
4763*da0073e9SAndroid Build Coastguard Worker def test_string_ops(self): 4764*da0073e9SAndroid Build Coastguard Worker def foo(): 4765*da0073e9SAndroid Build Coastguard Worker a = "a" + "b" 4766*da0073e9SAndroid Build Coastguard Worker return a + a, "ab" == "b", "ab" != "b", "ab" == "ab", "ab" != "ab" 4767*da0073e9SAndroid Build Coastguard Worker 4768*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, ()) 4769*da0073e9SAndroid Build Coastguard Worker 4770*da0073e9SAndroid Build Coastguard Worker def test_string_sorted(self): 4771*da0073e9SAndroid Build Coastguard Worker def foo(strs: List[str]): 4772*da0073e9SAndroid Build Coastguard Worker return sorted(strs) 4773*da0073e9SAndroid Build Coastguard Worker 4774*da0073e9SAndroid Build Coastguard Worker FileCheck() \ 4775*da0073e9SAndroid Build Coastguard Worker .check("graph") \ 4776*da0073e9SAndroid Build Coastguard Worker .check_next("str[] = aten::sorted") \ 4777*da0073e9SAndroid Build Coastguard Worker .check_next("return") \ 4778*da0073e9SAndroid Build Coastguard Worker .run(str(torch.jit.script(foo).graph)) 4779*da0073e9SAndroid Build Coastguard Worker 4780*da0073e9SAndroid Build Coastguard Worker inputs = ["str3", "str2", "str1"] 4781*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (inputs,)) 4782*da0073e9SAndroid Build Coastguard Worker 4783*da0073e9SAndroid Build Coastguard Worker def test_string_sort(self): 4784*da0073e9SAndroid Build Coastguard Worker def foo(strs: List[str]): 4785*da0073e9SAndroid Build Coastguard Worker strs.sort() 4786*da0073e9SAndroid Build Coastguard Worker return strs 4787*da0073e9SAndroid Build Coastguard Worker 4788*da0073e9SAndroid Build Coastguard Worker inputs = ["str3", "str2", "str1"] 4789*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (inputs,)) 4790*da0073e9SAndroid Build Coastguard Worker 4791*da0073e9SAndroid Build Coastguard Worker def test_tuple_sorted(self): 4792*da0073e9SAndroid Build Coastguard Worker def foo(tups: 
List[Tuple[int, int]]): 4793*da0073e9SAndroid Build Coastguard Worker return sorted(tups) 4794*da0073e9SAndroid Build Coastguard Worker 4795*da0073e9SAndroid Build Coastguard Worker inputs = [(1, 2), (0, 2), (1, 3)] 4796*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (inputs,)) 4797*da0073e9SAndroid Build Coastguard Worker 4798*da0073e9SAndroid Build Coastguard Worker def test_tuple_sort(self): 4799*da0073e9SAndroid Build Coastguard Worker def foo(tups: List[Tuple[int, int]]): 4800*da0073e9SAndroid Build Coastguard Worker tups.sort() 4801*da0073e9SAndroid Build Coastguard Worker return tups 4802*da0073e9SAndroid Build Coastguard Worker 4803*da0073e9SAndroid Build Coastguard Worker inputs = [(1, 2), (0, 2), (1, 3)] 4804*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (inputs,)) 4805*da0073e9SAndroid Build Coastguard Worker 4806*da0073e9SAndroid Build Coastguard Worker def test_tuple_sort_reverse(self): 4807*da0073e9SAndroid Build Coastguard Worker def foo(tups: List[Tuple[int, int]]): 4808*da0073e9SAndroid Build Coastguard Worker tups.sort(reverse=True) 4809*da0073e9SAndroid Build Coastguard Worker return tups 4810*da0073e9SAndroid Build Coastguard Worker 4811*da0073e9SAndroid Build Coastguard Worker inputs = [(1, 2), (0, 2), (1, 3)] 4812*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (inputs,)) 4813*da0073e9SAndroid Build Coastguard Worker 4814*da0073e9SAndroid Build Coastguard Worker def test_tuple_unsortable_element_type(self): 4815*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4816*da0073e9SAndroid Build Coastguard Worker def foo(): 4817*da0073e9SAndroid Build Coastguard Worker tups = [({1: 2}, {2: 3})] 4818*da0073e9SAndroid Build Coastguard Worker tups.sort() 4819*da0073e9SAndroid Build Coastguard Worker return tups 4820*da0073e9SAndroid Build Coastguard Worker 4821*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight(RuntimeError, "are not sortable", "tups.sort"): 
4822*da0073e9SAndroid Build Coastguard Worker foo() 4823*da0073e9SAndroid Build Coastguard Worker 4824*da0073e9SAndroid Build Coastguard Worker def test_tuple_unsortable_diff_type(self): 4825*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4826*da0073e9SAndroid Build Coastguard Worker def foo(inputs: List[Any]): 4827*da0073e9SAndroid Build Coastguard Worker inputs.sort() 4828*da0073e9SAndroid Build Coastguard Worker return inputs 4829*da0073e9SAndroid Build Coastguard Worker 4830*da0073e9SAndroid Build Coastguard Worker inputs = [(1, 2), ("foo", "bar")] 4831*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"): 4832*da0073e9SAndroid Build Coastguard Worker foo(inputs) 4833*da0073e9SAndroid Build Coastguard Worker 4834*da0073e9SAndroid Build Coastguard Worker def test_tuple_nested_sort(self): 4835*da0073e9SAndroid Build Coastguard Worker def foo(inputs: List[Tuple[int, Tuple[int, str]]]): 4836*da0073e9SAndroid Build Coastguard Worker inputs.sort() 4837*da0073e9SAndroid Build Coastguard Worker return inputs 4838*da0073e9SAndroid Build Coastguard Worker 4839*da0073e9SAndroid Build Coastguard Worker inputs = [(1, (2, "foo")), (1, (2, "bar")), (1, (0, "bar"))] 4840*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (inputs,)) 4841*da0073e9SAndroid Build Coastguard Worker 4842*da0073e9SAndroid Build Coastguard Worker def test_tuple_unsortable_nested_diff_type(self): 4843*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4844*da0073e9SAndroid Build Coastguard Worker def foo(inputs: List[Any]): 4845*da0073e9SAndroid Build Coastguard Worker inputs.sort() 4846*da0073e9SAndroid Build Coastguard Worker return inputs 4847*da0073e9SAndroid Build Coastguard Worker 4848*da0073e9SAndroid Build Coastguard Worker inputs = [(1, (2, 3)), (2, ("foo", "bar"))] 4849*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegexWithHighlight(RuntimeError, "Only values of same type can be compared", "inputs.sort"): 4850*da0073e9SAndroid Build Coastguard Worker foo(inputs) 4851*da0073e9SAndroid Build Coastguard Worker 4852*da0073e9SAndroid Build Coastguard Worker def test_string_new_line(self): 4853*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "expected a valid token*"): 4854*da0073e9SAndroid Build Coastguard Worker torch.jit.CompilationUnit(''' 4855*da0073e9SAndroid Build Coastguard Worker def test_while(a): 4856*da0073e9SAndroid Build Coastguard Worker print(" 4857*da0073e9SAndroid Build Coastguard Worker a") 4858*da0073e9SAndroid Build Coastguard Worker return a 4859*da0073e9SAndroid Build Coastguard Worker ''') 4860*da0073e9SAndroid Build Coastguard Worker 4861*da0073e9SAndroid Build Coastguard Worker def test_string_single_escape(self): 4862*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "expected a valid token*"): 4863*da0073e9SAndroid Build Coastguard Worker torch.jit.CompilationUnit(''' 4864*da0073e9SAndroid Build Coastguard Worker def test_while(a): 4865*da0073e9SAndroid Build Coastguard Worker print("\\") 4866*da0073e9SAndroid Build Coastguard Worker return a 4867*da0073e9SAndroid Build Coastguard Worker ''') 4868*da0073e9SAndroid Build Coastguard Worker 4869*da0073e9SAndroid Build Coastguard Worker def test_script_annotation(self): 4870*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4871*da0073e9SAndroid Build Coastguard Worker def foo(a): 4872*da0073e9SAndroid Build Coastguard Worker return a + a + a 4873*da0073e9SAndroid Build Coastguard Worker s = Variable(torch.rand(2)) 4874*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s + s + s, foo(s)) 4875*da0073e9SAndroid Build Coastguard Worker 4876*da0073e9SAndroid Build Coastguard Worker def test_torch_pow(self): 4877*da0073e9SAndroid Build Coastguard Worker def func(a, b): 4878*da0073e9SAndroid Build Coastguard 
Worker return pow(a, b) 4879*da0073e9SAndroid Build Coastguard Worker 4880*da0073e9SAndroid Build Coastguard Worker def func2(a, b, c, d): 4881*da0073e9SAndroid Build Coastguard Worker return pow(pow(c + a, b), d) 4882*da0073e9SAndroid Build Coastguard Worker 4883*da0073e9SAndroid Build Coastguard Worker def func3(a : int, b : float): 4884*da0073e9SAndroid Build Coastguard Worker # type: (int, float) -> float 4885*da0073e9SAndroid Build Coastguard Worker return pow(a, b) 4886*da0073e9SAndroid Build Coastguard Worker 4887*da0073e9SAndroid Build Coastguard Worker def func4(): 4888*da0073e9SAndroid Build Coastguard Worker # type: () -> float 4889*da0073e9SAndroid Build Coastguard Worker return pow(2, -2) 4890*da0073e9SAndroid Build Coastguard Worker 4891*da0073e9SAndroid Build Coastguard Worker def func5(x, y): 4892*da0073e9SAndroid Build Coastguard Worker return pow(x.item(), y.item()) 4893*da0073e9SAndroid Build Coastguard Worker 4894*da0073e9SAndroid Build Coastguard Worker def func6(a : int, b : int): 4895*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> float 4896*da0073e9SAndroid Build Coastguard Worker return pow(a, b) 4897*da0073e9SAndroid Build Coastguard Worker 4898*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1) 4899*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1) 4900*da0073e9SAndroid Build Coastguard Worker c = torch.rand(1) 4901*da0073e9SAndroid Build Coastguard Worker d = torch.rand(1) 4902*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (a, b)) 4903*da0073e9SAndroid Build Coastguard Worker self.checkScript(func2, (a, b, c, d)) 4904*da0073e9SAndroid Build Coastguard Worker self.checkScript(func3, (4, -0.5)) 4905*da0073e9SAndroid Build Coastguard Worker self.checkScript(func4, ()) 4906*da0073e9SAndroid Build Coastguard Worker self.checkScript(func6, (2, 4)) 4907*da0073e9SAndroid Build Coastguard Worker 4908*da0073e9SAndroid Build Coastguard Worker inputs = [torch.tensor(2), torch.tensor(-2), 
torch.tensor(.5), torch.tensor(.2)] 4909*da0073e9SAndroid Build Coastguard Worker for x in inputs: 4910*da0073e9SAndroid Build Coastguard Worker for y in inputs: 4911*da0073e9SAndroid Build Coastguard Worker if x < 0: 4912*da0073e9SAndroid Build Coastguard Worker continue 4913*da0073e9SAndroid Build Coastguard Worker else: 4914*da0073e9SAndroid Build Coastguard Worker self.checkScript(func5, (x, y)) 4915*da0073e9SAndroid Build Coastguard Worker 4916*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "device tests require CUDA") 4917*da0073e9SAndroid Build Coastguard Worker def test_pow_scalar_backward_cuda(self): 4918*da0073e9SAndroid Build Coastguard Worker # see that scalar exponent works with cuda base (#19253) 4919*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 4920*da0073e9SAndroid Build Coastguard Worker for dtype in [torch.float, torch.double]: 4921*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4922*da0073e9SAndroid Build Coastguard Worker def func(a, b): 4923*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, float) -> Tensor 4924*da0073e9SAndroid Build Coastguard Worker return (a * 2) ** b 4925*da0073e9SAndroid Build Coastguard Worker 4926*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype) 4927*da0073e9SAndroid Build Coastguard Worker func(a, 1, profile_and_replay=True).backward() 4928*da0073e9SAndroid Build Coastguard Worker 4929*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4930*da0073e9SAndroid Build Coastguard Worker def func(a, b): 4931*da0073e9SAndroid Build Coastguard Worker # type: (float, Tensor) -> Tensor 4932*da0073e9SAndroid Build Coastguard Worker return a ** (b * 2 + 1) 4933*da0073e9SAndroid Build Coastguard Worker 4934*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1, requires_grad=True, device='cuda', dtype=dtype) 4935*da0073e9SAndroid Build Coastguard Worker func(2, a, 
profile_and_replay=True).backward() 4936*da0073e9SAndroid Build Coastguard Worker 4937*da0073e9SAndroid Build Coastguard Worker def _check_code(self, code_str, fn_name, inputs): 4938*da0073e9SAndroid Build Coastguard Worker scope = {} 4939*da0073e9SAndroid Build Coastguard Worker exec(code_str, globals(), scope) 4940*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code_str) 4941*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cu.func(*inputs), scope[fn_name](*inputs)) 4942*da0073e9SAndroid Build Coastguard Worker 4943*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, 'no CUDA') 4944*da0073e9SAndroid Build Coastguard Worker def test_scriptmodule_releases_tensors_cuda(self): 4945*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 4946*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 4947*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 4948*da0073e9SAndroid Build Coastguard Worker return x.sigmoid() * y.tanh() 4949*da0073e9SAndroid Build Coastguard Worker 4950*da0073e9SAndroid Build Coastguard Worker def test(backward=False): 4951*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True) 4952*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True) 4953*da0073e9SAndroid Build Coastguard Worker out = fn(x, y, profile_and_replay=True) 4954*da0073e9SAndroid Build Coastguard Worker if backward: 4955*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 4956*da0073e9SAndroid Build Coastguard Worker 4957*da0073e9SAndroid Build Coastguard Worker with self.assertLeaksNoCudaTensors(): 4958*da0073e9SAndroid Build Coastguard Worker test() 4959*da0073e9SAndroid Build Coastguard Worker test() 4960*da0073e9SAndroid Build Coastguard Worker test() 4961*da0073e9SAndroid Build Coastguard Worker 4962*da0073e9SAndroid Build Coastguard Worker if 
GRAPH_EXECUTOR != ProfilingMode.SIMPLE: 4963*da0073e9SAndroid Build Coastguard Worker with self.assertLeaksNoCudaTensors(): 4964*da0073e9SAndroid Build Coastguard Worker test(backward=True) 4965*da0073e9SAndroid Build Coastguard Worker test(backward=True) 4966*da0073e9SAndroid Build Coastguard Worker test(backward=True) 4967*da0073e9SAndroid Build Coastguard Worker 4968*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a TorchDynamo suitable test") 4969*da0073e9SAndroid Build Coastguard Worker def test_index(self): 4970*da0073e9SAndroid Build Coastguard Worker def consec(size, start=0): 4971*da0073e9SAndroid Build Coastguard Worker numel = torch.tensor(size).prod().item() 4972*da0073e9SAndroid Build Coastguard Worker return torch.arange(numel).view(size) 4973*da0073e9SAndroid Build Coastguard Worker 4974*da0073e9SAndroid Build Coastguard Worker def consec_list(size): 4975*da0073e9SAndroid Build Coastguard Worker return list(range(size)) 4976*da0073e9SAndroid Build Coastguard Worker 4977*da0073e9SAndroid Build Coastguard Worker def random_string(size): 4978*da0073e9SAndroid Build Coastguard Worker letters = string.ascii_lowercase 4979*da0073e9SAndroid Build Coastguard Worker return "".join(random.choice(letters) for i in range(size)) 4980*da0073e9SAndroid Build Coastguard Worker 4981*da0073e9SAndroid Build Coastguard Worker def check_indexing(indexing, tensor): 4982*da0073e9SAndroid Build Coastguard Worker template = dedent(""" 4983*da0073e9SAndroid Build Coastguard Worker def func(x): 4984*da0073e9SAndroid Build Coastguard Worker return x{} 4985*da0073e9SAndroid Build Coastguard Worker """) 4986*da0073e9SAndroid Build Coastguard Worker 4987*da0073e9SAndroid Build Coastguard Worker self._check_code(template.format(indexing), "func", [tensor]) 4988*da0073e9SAndroid Build Coastguard Worker 4989*da0073e9SAndroid Build Coastguard Worker def check_dynamic_indexing(indexing, tensor, value1, value2): 4990*da0073e9SAndroid Build Coastguard Worker value1 = 
torch.tensor(value1) 4991*da0073e9SAndroid Build Coastguard Worker value2 = torch.tensor(value2) 4992*da0073e9SAndroid Build Coastguard Worker 4993*da0073e9SAndroid Build Coastguard Worker template = dedent(""" 4994*da0073e9SAndroid Build Coastguard Worker def func(x, value1, value2): 4995*da0073e9SAndroid Build Coastguard Worker i = int(value1) 4996*da0073e9SAndroid Build Coastguard Worker j = int(value2) 4997*da0073e9SAndroid Build Coastguard Worker return x{} 4998*da0073e9SAndroid Build Coastguard Worker """) 4999*da0073e9SAndroid Build Coastguard Worker 5000*da0073e9SAndroid Build Coastguard Worker self._check_code(template.format(indexing), "func", [tensor, value1, value2]) 5001*da0073e9SAndroid Build Coastguard Worker 5002*da0073e9SAndroid Build Coastguard Worker # Torchscript assumes type Tensor by default, so we need this explicit 5003*da0073e9SAndroid Build Coastguard Worker # declaration. 5004*da0073e9SAndroid Build Coastguard Worker def check_indexing_list_int(indexing, list): 5005*da0073e9SAndroid Build Coastguard Worker template = dedent(""" 5006*da0073e9SAndroid Build Coastguard Worker def func(x): 5007*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> Any 5008*da0073e9SAndroid Build Coastguard Worker return x{} 5009*da0073e9SAndroid Build Coastguard Worker """) 5010*da0073e9SAndroid Build Coastguard Worker 5011*da0073e9SAndroid Build Coastguard Worker self._check_code(template.format(indexing), "func", [list]) 5012*da0073e9SAndroid Build Coastguard Worker 5013*da0073e9SAndroid Build Coastguard Worker def check_indexing_str(indexing, str): 5014*da0073e9SAndroid Build Coastguard Worker template = dedent(""" 5015*da0073e9SAndroid Build Coastguard Worker def func(x): 5016*da0073e9SAndroid Build Coastguard Worker # type: (str) -> Any 5017*da0073e9SAndroid Build Coastguard Worker return x{} 5018*da0073e9SAndroid Build Coastguard Worker """) 5019*da0073e9SAndroid Build Coastguard Worker 5020*da0073e9SAndroid Build Coastguard Worker 
self._check_code(template.format(indexing), "func", [str]) 5021*da0073e9SAndroid Build Coastguard Worker 5022*da0073e9SAndroid Build Coastguard Worker # basic slices 5023*da0073e9SAndroid Build Coastguard Worker check_indexing('[0]', consec((3, 3))) 5024*da0073e9SAndroid Build Coastguard Worker check_indexing('[1]', consec((3, 3), 10)) 5025*da0073e9SAndroid Build Coastguard Worker check_indexing('[2]', consec((3, 3), 19)) 5026*da0073e9SAndroid Build Coastguard Worker check_indexing('[2]', consec((3,))) 5027*da0073e9SAndroid Build Coastguard Worker check_indexing('[-1]', consec((3, 3), 19)) 5028*da0073e9SAndroid Build Coastguard Worker check_indexing('[0:2]', consec((3, 3, 3))) 5029*da0073e9SAndroid Build Coastguard Worker check_indexing('[1:-1]', consec((3, 3, 3))) 5030*da0073e9SAndroid Build Coastguard Worker check_indexing('[-3:-1]', consec((6, 3))) 5031*da0073e9SAndroid Build Coastguard Worker check_indexing('[1:]', consec((3, 3))) 5032*da0073e9SAndroid Build Coastguard Worker check_indexing('[:1]', consec((3, 3))) 5033*da0073e9SAndroid Build Coastguard Worker check_indexing('[:]', consec((3, 2))) 5034*da0073e9SAndroid Build Coastguard Worker 5035*da0073e9SAndroid Build Coastguard Worker # multi-dim: indexes 5036*da0073e9SAndroid Build Coastguard Worker check_indexing('[0, 1]', consec((3, 3))) 5037*da0073e9SAndroid Build Coastguard Worker check_indexing('[0, 1]', consec((3, 3, 2))) 5038*da0073e9SAndroid Build Coastguard Worker check_indexing('[1, 0, 2]', consec((3, 3, 3))) 5039*da0073e9SAndroid Build Coastguard Worker check_indexing('[2, -1]', consec((3, 3))) 5040*da0073e9SAndroid Build Coastguard Worker 5041*da0073e9SAndroid Build Coastguard Worker # multi-dim: mixed slicing and indexing 5042*da0073e9SAndroid Build Coastguard Worker check_indexing('[0, 1:2]', consec((3, 3))) 5043*da0073e9SAndroid Build Coastguard Worker check_indexing('[0, :1]', consec((3, 3, 2))) 5044*da0073e9SAndroid Build Coastguard Worker check_indexing('[1, 2:]', consec((3, 3, 3))) 
5045*da0073e9SAndroid Build Coastguard Worker check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3))) 5046*da0073e9SAndroid Build Coastguard Worker check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3))) 5047*da0073e9SAndroid Build Coastguard Worker check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3))) 5048*da0073e9SAndroid Build Coastguard Worker check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3))) 5049*da0073e9SAndroid Build Coastguard Worker check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3))) 5050*da0073e9SAndroid Build Coastguard Worker 5051*da0073e9SAndroid Build Coastguard Worker # zero-sized slices 5052*da0073e9SAndroid Build Coastguard Worker check_indexing('[0:0]', consec((2, 2))) 5053*da0073e9SAndroid Build Coastguard Worker check_indexing('[0:0, 1]', consec((3, 3))) 5054*da0073e9SAndroid Build Coastguard Worker 5055*da0073e9SAndroid Build Coastguard Worker # trivial expression usage 5056*da0073e9SAndroid Build Coastguard Worker check_indexing('[1+1]', consec((3, 3))) 5057*da0073e9SAndroid Build Coastguard Worker check_indexing('[1:(0 + 2)]', consec((3, 3, 3))) 5058*da0073e9SAndroid Build Coastguard Worker 5059*da0073e9SAndroid Build Coastguard Worker # None for new dimensions 5060*da0073e9SAndroid Build Coastguard Worker check_indexing('[None, 0]', consec((3, 3))) 5061*da0073e9SAndroid Build Coastguard Worker check_indexing('[1, None]', consec((3, 3), 10)) 5062*da0073e9SAndroid Build Coastguard Worker check_indexing('[None, None, 2]', consec((3, 3), 19)) 5063*da0073e9SAndroid Build Coastguard Worker check_indexing('[None, 2, None]', consec((3,))) 5064*da0073e9SAndroid Build Coastguard Worker check_indexing('[0:2, None]', consec((3, 3, 3))) 5065*da0073e9SAndroid Build Coastguard Worker check_indexing('[None, 1:-1]', consec((3, 3, 3))) 5066*da0073e9SAndroid Build Coastguard Worker check_indexing('[None, -3:-1, None]', consec((6, 3))) 5067*da0073e9SAndroid Build Coastguard Worker check_indexing('[-1, None, 2:, None, 1:2]', consec((3, 3, 3, 3))) 
5068*da0073e9SAndroid Build Coastguard Worker check_indexing('[None, -1, None, 2:, None, 1:2, None]', consec((3, 3, 3, 3))) 5069*da0073e9SAndroid Build Coastguard Worker 5070*da0073e9SAndroid Build Coastguard Worker # dynamic expression usage 5071*da0073e9SAndroid Build Coastguard Worker check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1) 5072*da0073e9SAndroid Build Coastguard Worker check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2) 5073*da0073e9SAndroid Build Coastguard Worker 5074*da0073e9SAndroid Build Coastguard Worker # positive striding 5075*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[0]', consec_list(6)) 5076*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[1]', consec_list(7)) 5077*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[2]', consec_list(8)) 5078*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[2]', consec_list(9)) 5079*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[-1]', consec_list(10)) 5080*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[0:2]', consec_list(11)) 5081*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[1:-1]', consec_list(12)) 5082*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[-3:-1]', consec_list(13)) 5083*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[1:]', consec_list(15)) 5084*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[:1]', consec_list(16)) 5085*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[:]', consec_list(17)) 5086*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::]', consec_list(0)) 5087*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[1000::]', consec_list(0)) 5088*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[:1000:]', consec_list(0)) 5089*da0073e9SAndroid Build Coastguard Worker 5090*da0073e9SAndroid Build Coastguard Worker # negative 
striding 5091*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::-1]', consec_list(7)) 5092*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[:3:-1]', consec_list(7)) 5093*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[3::-1]', consec_list(7)) 5094*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[1000::-1]', consec_list(7)) 5095*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[3:0:-1]', consec_list(7)) 5096*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[3:-1000:-1]', consec_list(7)) 5097*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[0:0:-1]', consec_list(7)) 5098*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[0:-1000:-1]', consec_list(7)) 5099*da0073e9SAndroid Build Coastguard Worker 5100*da0073e9SAndroid Build Coastguard Worker # only step is specified 5101*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::-1]', consec_list(0)) 5102*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::-1]', consec_list(7)) 5103*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::-2]', consec_list(7)) 5104*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::2]', consec_list(7)) 5105*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::42]', consec_list(7)) 5106*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::-42]', consec_list(7)) 5107*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::42]', consec_list(0)) 5108*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::-42]', consec_list(0)) 5109*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::9223372036854775807]', consec_list(42)) 5110*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::-9223372036854775807]', consec_list(42)) 5111*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, "out of bounds"): 5112*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::-9223372036854775808]', consec_list(42)) 5113*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "should have non-zero step"): 5114*da0073e9SAndroid Build Coastguard Worker check_indexing_list_int('[::0]', consec_list(42)) 5115*da0073e9SAndroid Build Coastguard Worker 5116*da0073e9SAndroid Build Coastguard Worker # striding strings 5117*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[0]', random_string(6)) 5118*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[1]', random_string(7)) 5119*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[2]', random_string(8)) 5120*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[2]', random_string(9)) 5121*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[-1]', random_string(10)) 5122*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[0:2]', random_string(11)) 5123*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[1:-1]', random_string(12)) 5124*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[-3:-1]', random_string(13)) 5125*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[1:]', random_string(15)) 5126*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[:1]', random_string(16)) 5127*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[:]', random_string(17)) 5128*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::]', random_string(0)) 5129*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[1000::]', random_string(0)) 5130*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[:1000:]', random_string(0)) 5131*da0073e9SAndroid Build Coastguard Worker 5132*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::-1]', random_string(7)) 5133*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[:3:-1]', 
random_string(7)) 5134*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[3::-1]', random_string(7)) 5135*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[1000::-1]', random_string(7)) 5136*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[3:0:-1]', random_string(7)) 5137*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[3:-1000:-1]', random_string(7)) 5138*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[0:0:-1]', random_string(7)) 5139*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[0:-1000:-1]', random_string(7)) 5140*da0073e9SAndroid Build Coastguard Worker 5141*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::-1]', random_string(0)) 5142*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::-1]', random_string(7)) 5143*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::-2]', random_string(7)) 5144*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::2]', random_string(7)) 5145*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::42]', random_string(7)) 5146*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::-42]', random_string(7)) 5147*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::42]', random_string(0)) 5148*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::-42]', random_string(0)) 5149*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::9223372036854775807]', random_string(42)) 5150*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::-9223372036854775807]', random_string(42)) 5151*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "out of bounds"): 5152*da0073e9SAndroid Build Coastguard Worker check_indexing_str('[::-9223372036854775808]', random_string(42)) 5153*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "should have non-zero step"): 5154*da0073e9SAndroid Build Coastguard Worker 
check_indexing_str('[::0]', random_string(42)) 5155*da0073e9SAndroid Build Coastguard Worker 5156*da0073e9SAndroid Build Coastguard Worker def test_module_copy_with_attributes(self): 5157*da0073e9SAndroid Build Coastguard Worker class Vocabulary(torch.jit.ScriptModule): 5158*da0073e9SAndroid Build Coastguard Worker def __init__(self, vocab_list): 5159*da0073e9SAndroid Build Coastguard Worker super().__init__() 5160*da0073e9SAndroid Build Coastguard Worker self._vocab = torch.jit.Attribute(vocab_list, List[str]) 5161*da0073e9SAndroid Build Coastguard Worker self.some_idx = torch.jit.Attribute(2, int) 5162*da0073e9SAndroid Build Coastguard Worker self.idx = torch.jit.Attribute( 5163*da0073e9SAndroid Build Coastguard Worker {word: i for i, word in enumerate(vocab_list)}, Dict[str, int] 5164*da0073e9SAndroid Build Coastguard Worker ) 5165*da0073e9SAndroid Build Coastguard Worker 5166*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 5167*da0073e9SAndroid Build Coastguard Worker def lookup_indices_1d(self, values): 5168*da0073e9SAndroid Build Coastguard Worker # type: (List[str]) -> List[int] 5169*da0073e9SAndroid Build Coastguard Worker result = torch.jit.annotate(List[int], []) 5170*da0073e9SAndroid Build Coastguard Worker # Direct list iteration not supported 5171*da0073e9SAndroid Build Coastguard Worker for i in range(len(values)): 5172*da0073e9SAndroid Build Coastguard Worker value = values[i] 5173*da0073e9SAndroid Build Coastguard Worker result.append(self.idx.get(value, self.some_idx)) 5174*da0073e9SAndroid Build Coastguard Worker return result 5175*da0073e9SAndroid Build Coastguard Worker 5176*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 5177*da0073e9SAndroid Build Coastguard Worker def forward(self, values): 5178*da0073e9SAndroid Build Coastguard Worker # type: (List[List[str]]) -> List[List[int]] 5179*da0073e9SAndroid Build Coastguard Worker result = torch.jit.annotate(List[List[int]], []) 5180*da0073e9SAndroid Build 
Coastguard Worker # Direct list iteration not supported 5181*da0073e9SAndroid Build Coastguard Worker for i in range(len(values)): 5182*da0073e9SAndroid Build Coastguard Worker result.append(self.lookup_indices_1d(values[i])) 5183*da0073e9SAndroid Build Coastguard Worker return result 5184*da0073e9SAndroid Build Coastguard Worker 5185*da0073e9SAndroid Build Coastguard Worker v = Vocabulary(list('uabcdefg')) 5186*da0073e9SAndroid Build Coastguard Worker v.__copy__() 5187*da0073e9SAndroid Build Coastguard Worker 5188*da0073e9SAndroid Build Coastguard Worker def test_tuple_to_opt_list(self): 5189*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5190*da0073e9SAndroid Build Coastguard Worker def foo(x): 5191*da0073e9SAndroid Build Coastguard Worker # type: (Optional[List[int]]) -> int 5192*da0073e9SAndroid Build Coastguard Worker return 1 5193*da0073e9SAndroid Build Coastguard Worker 5194*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5195*da0073e9SAndroid Build Coastguard Worker def tuple_call(): 5196*da0073e9SAndroid Build Coastguard Worker return foo((1, 2)) 5197*da0073e9SAndroid Build Coastguard Worker 5198*da0073e9SAndroid Build Coastguard Worker def test_keyword(self): 5199*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5200*da0073e9SAndroid Build Coastguard Worker def func(x): 5201*da0073e9SAndroid Build Coastguard Worker return torch.sum(x, dim=0) 5202*da0073e9SAndroid Build Coastguard Worker 5203*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, dtype=torch.float, requires_grad=True) 5204*da0073e9SAndroid Build Coastguard Worker y = func(x) 5205*da0073e9SAndroid Build Coastguard Worker y2 = torch.sum(x, dim=0) 5206*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, y2) 5207*da0073e9SAndroid Build Coastguard Worker 5208*da0073e9SAndroid Build Coastguard Worker def test_constant_pooling_none(self): 5209*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5210*da0073e9SAndroid Build Coastguard Worker def 
typed_nones(a=None, b=None, c=None): 5211*da0073e9SAndroid Build Coastguard Worker # type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]] 5212*da0073e9SAndroid Build Coastguard Worker return a, b, c 5213*da0073e9SAndroid Build Coastguard Worker 5214*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5215*da0073e9SAndroid Build Coastguard Worker def test(a): 5216*da0073e9SAndroid Build Coastguard Worker # type: (bool) -> None 5217*da0073e9SAndroid Build Coastguard Worker if a: 5218*da0073e9SAndroid Build Coastguard Worker print(typed_nones()) 5219*da0073e9SAndroid Build Coastguard Worker else: 5220*da0073e9SAndroid Build Coastguard Worker print(typed_nones()) 5221*da0073e9SAndroid Build Coastguard Worker 5222*da0073e9SAndroid Build Coastguard Worker graph_str = str(test.graph) 5223*da0073e9SAndroid Build Coastguard Worker self.assertTrue(graph_str.count("NoneType = prim::Constant") == 1) 5224*da0073e9SAndroid Build Coastguard Worker 5225*da0073e9SAndroid Build Coastguard Worker def test_constant_pooling_same_identity(self): 5226*da0073e9SAndroid Build Coastguard Worker def foo(): 5227*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([4]) 5228*da0073e9SAndroid Build Coastguard Worker b = (a,) 5229*da0073e9SAndroid Build Coastguard Worker index = len(a) - 1 5230*da0073e9SAndroid Build Coastguard Worker c = b[index] 5231*da0073e9SAndroid Build Coastguard Worker d = b[index] 5232*da0073e9SAndroid Build Coastguard Worker return c, d 5233*da0073e9SAndroid Build Coastguard Worker 5234*da0073e9SAndroid Build Coastguard Worker foo_script = torch.jit.script(foo) 5235*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_propagation', foo_script.graph) 5236*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_pooling', foo_script.graph) 5237*da0073e9SAndroid Build Coastguard Worker # even though the c & d escape scope, we are still able 5238*da0073e9SAndroid Build Coastguard 
Worker # pool them into one constant because they are the same object 5239*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::Constant", 1, exactly=True).run(foo_script.graph) 5240*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), foo_script()) 5241*da0073e9SAndroid Build Coastguard Worker 5242*da0073e9SAndroid Build Coastguard Worker def test_constant_pooling_introduce_aliasing(self): 5243*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5244*da0073e9SAndroid Build Coastguard Worker def foo(): 5245*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(1) 5246*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(1) 5247*da0073e9SAndroid Build Coastguard Worker return a, b 5248*da0073e9SAndroid Build Coastguard Worker 5249*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_propagation', foo.graph) 5250*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_pooling', foo.graph) 5251*da0073e9SAndroid Build Coastguard Worker # dont pool constants bc it would introduce observable alias relationship changing 5252*da0073e9SAndroid Build Coastguard Worker a, b = foo() 5253*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(a, b) 5254*da0073e9SAndroid Build Coastguard Worker 5255*da0073e9SAndroid Build Coastguard Worker def test_literal(self): 5256*da0073e9SAndroid Build Coastguard Worker def func1(a, b): 5257*da0073e9SAndroid Build Coastguard Worker c = a, b 5258*da0073e9SAndroid Build Coastguard Worker d, e = c 5259*da0073e9SAndroid Build Coastguard Worker return d + e 5260*da0073e9SAndroid Build Coastguard Worker 5261*da0073e9SAndroid Build Coastguard Worker def func2(a, b): 5262*da0073e9SAndroid Build Coastguard Worker c = a, (a, b) 5263*da0073e9SAndroid Build Coastguard Worker d, e = c 5264*da0073e9SAndroid Build Coastguard Worker f, g = e 5265*da0073e9SAndroid Build Coastguard Worker return d + f + g 5266*da0073e9SAndroid Build Coastguard Worker 5267*da0073e9SAndroid Build Coastguard 
Worker def func3(a, b): 5268*da0073e9SAndroid Build Coastguard Worker # type: (float, float) -> float 5269*da0073e9SAndroid Build Coastguard Worker c = 0., (0., 0.) 5270*da0073e9SAndroid Build Coastguard Worker x = True 5271*da0073e9SAndroid Build Coastguard Worker while x: 5272*da0073e9SAndroid Build Coastguard Worker x = False 5273*da0073e9SAndroid Build Coastguard Worker c = a, (a, b) 5274*da0073e9SAndroid Build Coastguard Worker d, e = c 5275*da0073e9SAndroid Build Coastguard Worker f, g = e 5276*da0073e9SAndroid Build Coastguard Worker return d + f + g 5277*da0073e9SAndroid Build Coastguard Worker 5278*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1, requires_grad=True) 5279*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1, requires_grad=True) 5280*da0073e9SAndroid Build Coastguard Worker self.checkScript(func1, (a, b), optimize=True) 5281*da0073e9SAndroid Build Coastguard Worker self.checkScript(func2, (a, b), optimize=True) 5282*da0073e9SAndroid Build Coastguard Worker self.checkScript(func3, (a.item(), b.item()), optimize=True) 5283*da0073e9SAndroid Build Coastguard Worker 5284*da0073e9SAndroid Build Coastguard Worker def test_expand(self): 5285*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5286*da0073e9SAndroid Build Coastguard Worker def func(x, y): 5287*da0073e9SAndroid Build Coastguard Worker return x + y 5288*da0073e9SAndroid Build Coastguard Worker 5289*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, 3, dtype=torch.float, requires_grad=True) 5290*da0073e9SAndroid Build Coastguard Worker y = torch.rand(3, dtype=torch.float, requires_grad=True) 5291*da0073e9SAndroid Build Coastguard Worker out = func(x, y) 5292*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(x, y), x + y) 5293*da0073e9SAndroid Build Coastguard Worker 5294*da0073e9SAndroid Build Coastguard Worker grad = torch.randn(2, 3, dtype=torch.float) 5295*da0073e9SAndroid Build Coastguard Worker out.backward(grad) 5296*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(x.grad, grad) 5297*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad, grad.sum(dim=0)) 5298*da0073e9SAndroid Build Coastguard Worker 5299*da0073e9SAndroid Build Coastguard Worker def test_sum(self): 5300*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5301*da0073e9SAndroid Build Coastguard Worker def func(x): 5302*da0073e9SAndroid Build Coastguard Worker return x.sum(dim=[4]) 5303*da0073e9SAndroid Build Coastguard Worker 5304*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5305*da0073e9SAndroid Build Coastguard Worker def func2(x): 5306*da0073e9SAndroid Build Coastguard Worker return x.sum(dim=4) 5307*da0073e9SAndroid Build Coastguard Worker 5308*da0073e9SAndroid Build Coastguard Worker # test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument 5309*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_propagation', func.graph) 5310*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_propagation', func2.graph) 5311*da0073e9SAndroid Build Coastguard Worker g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False) 5312*da0073e9SAndroid Build Coastguard Worker g2 = _propagate_shapes(func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False) 5313*da0073e9SAndroid Build Coastguard Worker 5314*da0073e9SAndroid Build Coastguard Worker def test_cat(self): 5315*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 5316*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5317*da0073e9SAndroid Build Coastguard Worker def func(x): 5318*da0073e9SAndroid Build Coastguard Worker return torch.cat((x, x), dim=0) 5319*da0073e9SAndroid Build Coastguard Worker 5320*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, dtype=torch.float, requires_grad=True) 5321*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(x, profile_and_replay=True), torch.cat((x, x), dim=0)) 5322*da0073e9SAndroid 
Build Coastguard Worker 5323*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5324*da0073e9SAndroid Build Coastguard Worker def func2(x, y): 5325*da0073e9SAndroid Build Coastguard Worker return torch.cat((x, x), y) 5326*da0073e9SAndroid Build Coastguard Worker 5327*da0073e9SAndroid Build Coastguard Worker with disable_autodiff_subgraph_inlining(): 5328*da0073e9SAndroid Build Coastguard Worker for sizes in ((2, 2), (0, 2)): 5329*da0073e9SAndroid Build Coastguard Worker x = torch.rand(sizes).requires_grad_() 5330*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(1) 5331*da0073e9SAndroid Build Coastguard Worker 5332*da0073e9SAndroid Build Coastguard Worker output = func2(x, y, profile_and_replay=True) 5333*da0073e9SAndroid Build Coastguard Worker output_ref = torch.cat((x, x), y) 5334*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, output_ref) 5335*da0073e9SAndroid Build Coastguard Worker 5336*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: 5337*da0073e9SAndroid Build Coastguard Worker self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], []) 5338*da0073e9SAndroid Build Coastguard Worker 5339*da0073e9SAndroid Build Coastguard Worker grad = torch.autograd.grad(output.sum(), x) 5340*da0073e9SAndroid Build Coastguard Worker grad_ref = torch.autograd.grad(output_ref.sum(), x) 5341*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad, grad_ref) 5342*da0073e9SAndroid Build Coastguard Worker 5343*da0073e9SAndroid Build Coastguard Worker def test_cat_lifts(self): 5344*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5345*da0073e9SAndroid Build Coastguard Worker def foo(x): 5346*da0073e9SAndroid Build Coastguard Worker return torch.cat([x, x], dim=1) 5347*da0073e9SAndroid Build Coastguard Worker 5348*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5349*da0073e9SAndroid Build Coastguard Worker def foo2(x): 5350*da0073e9SAndroid Build Coastguard Worker return 
torch.cat([], dim=1) 5351*da0073e9SAndroid Build Coastguard Worker 5352*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5353*da0073e9SAndroid Build Coastguard Worker def foo3(x): 5354*da0073e9SAndroid Build Coastguard Worker return torch.cat([x], dim=1) 5355*da0073e9SAndroid Build Coastguard Worker 5356*da0073e9SAndroid Build Coastguard Worker for g in [foo.graph, foo2.graph, foo3.graph]: 5357*da0073e9SAndroid Build Coastguard Worker FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g)) 5358*da0073e9SAndroid Build Coastguard Worker 5359*da0073e9SAndroid Build Coastguard Worker def test_stack(self): 5360*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 5361*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5362*da0073e9SAndroid Build Coastguard Worker def func(x): 5363*da0073e9SAndroid Build Coastguard Worker return torch.stack((x, x), dim=1) 5364*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, 10) 5365*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(x, profile_and_replay=True), torch.stack((x, x), dim=1)) 5366*da0073e9SAndroid Build Coastguard Worker 5367*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5368*da0073e9SAndroid Build Coastguard Worker def func2(x, y): 5369*da0073e9SAndroid Build Coastguard Worker return torch.stack((x, y), dim=0) 5370*da0073e9SAndroid Build Coastguard Worker 5371*da0073e9SAndroid Build Coastguard Worker with disable_autodiff_subgraph_inlining(): 5372*da0073e9SAndroid Build Coastguard Worker x = torch.randn([2, 2]).requires_grad_() 5373*da0073e9SAndroid Build Coastguard Worker y = torch.randn([2, 2]).requires_grad_() 5374*da0073e9SAndroid Build Coastguard Worker 5375*da0073e9SAndroid Build Coastguard Worker output = func2(x, y, profile_and_replay=True) 5376*da0073e9SAndroid Build Coastguard Worker output_ref = torch.stack((x, y), 0) 5377*da0073e9SAndroid Build Coastguard Worker self.assertEqual(output, 
output_ref) 5378*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: 5379*da0073e9SAndroid Build Coastguard Worker self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], []) 5380*da0073e9SAndroid Build Coastguard Worker 5381*da0073e9SAndroid Build Coastguard Worker grads = torch.autograd.grad(output.sum(), (x, y)) 5382*da0073e9SAndroid Build Coastguard Worker grads_ref = torch.autograd.grad(output_ref.sum(), (x, y)) 5383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grads, grads_ref) 5384*da0073e9SAndroid Build Coastguard Worker 5385*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, 5386*da0073e9SAndroid Build Coastguard Worker "Profiling executor will be using different heuristics for constructing differentiable graphs") 5387*da0073e9SAndroid Build Coastguard Worker def test_unbind(self): 5388*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 5389*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5390*da0073e9SAndroid Build Coastguard Worker def func(x, y): 5391*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, int) -> List[Tensor] 5392*da0073e9SAndroid Build Coastguard Worker return torch.unbind(x, y) 5393*da0073e9SAndroid Build Coastguard Worker 5394*da0073e9SAndroid Build Coastguard Worker with disable_autodiff_subgraph_inlining(): 5395*da0073e9SAndroid Build Coastguard Worker x = torch.rand([2, 2]).requires_grad_() 5396*da0073e9SAndroid Build Coastguard Worker y = 0 5397*da0073e9SAndroid Build Coastguard Worker outputs = func(x, y, profile_and_replay=True) 5398*da0073e9SAndroid Build Coastguard Worker outputs_ref = torch.unbind(x, dim=y) 5399*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs, outputs_ref) 5400*da0073e9SAndroid Build Coastguard Worker self.assertAutodiffNode(func.graph_for(x, y), True, [], []) 5401*da0073e9SAndroid Build Coastguard Worker 5402*da0073e9SAndroid Build 
Coastguard Worker grad = torch.autograd.grad(_sum_of_list(outputs), x) 5403*da0073e9SAndroid Build Coastguard Worker grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x) 5404*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad, grad_ref) 5405*da0073e9SAndroid Build Coastguard Worker 5406*da0073e9SAndroid Build Coastguard Worker 5407*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, 5408*da0073e9SAndroid Build Coastguard Worker "Profiling executor fails to recognize that tensors in a list require gradients") 5409*da0073e9SAndroid Build Coastguard Worker def test_meshgrid(self): 5410*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 5411*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5412*da0073e9SAndroid Build Coastguard Worker def func(a): 5413*da0073e9SAndroid Build Coastguard Worker # type: (List[Tensor]) -> List[Tensor] 5414*da0073e9SAndroid Build Coastguard Worker return torch.meshgrid(a) 5415*da0073e9SAndroid Build Coastguard Worker with disable_autodiff_subgraph_inlining(): 5416*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1.0, 2, 3]).requires_grad_() 5417*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([1.0, 2, 3, 4]).requires_grad_() 5418*da0073e9SAndroid Build Coastguard Worker inputs = [a, b] 5419*da0073e9SAndroid Build Coastguard Worker 5420*da0073e9SAndroid Build Coastguard Worker outputs_ref = torch.meshgrid(inputs) 5421*da0073e9SAndroid Build Coastguard Worker outputs = func(inputs, profile_and_replay=True) 5422*da0073e9SAndroid Build Coastguard Worker self.assertEqual(outputs, outputs_ref) 5423*da0073e9SAndroid Build Coastguard Worker 5424*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: 5425*da0073e9SAndroid Build Coastguard Worker self.assertAutodiffNode(func.graph_for(inputs), True, [], []) 5426*da0073e9SAndroid Build Coastguard Worker 5427*da0073e9SAndroid Build 
Coastguard Worker grads = torch.autograd.grad(_sum_of_list(outputs), inputs) 5428*da0073e9SAndroid Build Coastguard Worker grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs) 5429*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grads, grads_ref) 5430*da0073e9SAndroid Build Coastguard Worker 5431*da0073e9SAndroid Build Coastguard Worker def test_tensor_len(self): 5432*da0073e9SAndroid Build Coastguard Worker def func(x): 5433*da0073e9SAndroid Build Coastguard Worker return len(x) 5434*da0073e9SAndroid Build Coastguard Worker 5435*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, [torch.ones(4, 5, 6)]) 5436*da0073e9SAndroid Build Coastguard Worker 5437*da0073e9SAndroid Build Coastguard Worker def test_func_call(self): 5438*da0073e9SAndroid Build Coastguard Worker def add(a, b): 5439*da0073e9SAndroid Build Coastguard Worker return a + b 5440*da0073e9SAndroid Build Coastguard Worker 5441*da0073e9SAndroid Build Coastguard Worker def mul(a, x): 5442*da0073e9SAndroid Build Coastguard Worker return a * x 5443*da0073e9SAndroid Build Coastguard Worker 5444*da0073e9SAndroid Build Coastguard Worker def func(alpha, beta, x, y): 5445*da0073e9SAndroid Build Coastguard Worker return add(mul(alpha, x), mul(beta, y)) 5446*da0073e9SAndroid Build Coastguard Worker 5447*da0073e9SAndroid Build Coastguard Worker alpha = torch.rand(1, dtype=torch.float, requires_grad=True) 5448*da0073e9SAndroid Build Coastguard Worker beta = torch.rand(1, dtype=torch.float, requires_grad=True) 5449*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, dtype=torch.float, requires_grad=True) 5450*da0073e9SAndroid Build Coastguard Worker y = torch.rand(3, dtype=torch.float, requires_grad=True) 5451*da0073e9SAndroid Build Coastguard Worker 5452*da0073e9SAndroid Build Coastguard Worker # NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs 5453*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, [alpha, beta, x, y], 
optimize=False) 5454*da0073e9SAndroid Build Coastguard Worker 5455*da0073e9SAndroid Build Coastguard Worker @unittest.skip("bailouts are being deprecated") 5456*da0073e9SAndroid Build Coastguard Worker def test_profiling_graph_executor(self): 5457*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5458*da0073e9SAndroid Build Coastguard Worker def def_in_one_branch(x, z): 5459*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool) -> float 5460*da0073e9SAndroid Build Coastguard Worker y = x 5461*da0073e9SAndroid Build Coastguard Worker if z is False: 5462*da0073e9SAndroid Build Coastguard Worker y = x + 1 5463*da0073e9SAndroid Build Coastguard Worker 5464*da0073e9SAndroid Build Coastguard Worker return y.sum() 5465*da0073e9SAndroid Build Coastguard Worker 5466*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 3) 5467*da0073e9SAndroid Build Coastguard Worker 5468*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 5469*da0073e9SAndroid Build Coastguard Worker # check prim::profile are inserted 5470*da0073e9SAndroid Build Coastguard Worker profiled_graph_str = str(def_in_one_branch.graph_for(a, True)) 5471*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::profile", 4).run(profiled_graph_str) 5472*da0073e9SAndroid Build Coastguard Worker # this call is optimized for 5473*da0073e9SAndroid Build Coastguard Worker # the given shape of (2, 3) 5474*da0073e9SAndroid Build Coastguard Worker def_in_one_branch(a, False) 5475*da0073e9SAndroid Build Coastguard Worker # change shape to (3) 5476*da0073e9SAndroid Build Coastguard Worker # so we go down a bailout path 5477*da0073e9SAndroid Build Coastguard Worker a = torch.ones(3) 5478*da0073e9SAndroid Build Coastguard Worker # check prim::BailOuts are inserted 5479*da0073e9SAndroid Build Coastguard Worker bailout_graph_str = str(def_in_one_branch.graph_for(a, True)) 5480*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check_count("prim::BailOut", 3).run(bailout_graph_str) 5481*da0073e9SAndroid Build Coastguard Worker # this triggers all 3 bailouts 5482*da0073e9SAndroid Build Coastguard Worker self.assertEqual(def_in_one_branch(a, False), 6.0) 5483*da0073e9SAndroid Build Coastguard Worker # this triggers 2 bailouts 5484*da0073e9SAndroid Build Coastguard Worker self.assertEqual(def_in_one_branch(a, True), 3.0) 5485*da0073e9SAndroid Build Coastguard Worker 5486*da0073e9SAndroid Build Coastguard Worker @unittest.skip("bailouts are being deprecated") 5487*da0073e9SAndroid Build Coastguard Worker def test_maxpool_guard_elimination(self): 5488*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5489*da0073e9SAndroid Build Coastguard Worker def my_maxpool(x): 5490*da0073e9SAndroid Build Coastguard Worker return F.max_pool1d(x, kernel_size=[1]) + torch.ones([32, 32, 32]) 5491*da0073e9SAndroid Build Coastguard Worker 5492*da0073e9SAndroid Build Coastguard Worker a = torch.rand(32, 32, 32) 5493*da0073e9SAndroid Build Coastguard Worker 5494*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 5495*da0073e9SAndroid Build Coastguard Worker my_maxpool(a) 5496*da0073e9SAndroid Build Coastguard Worker bailout_graph_str = str(my_maxpool.graph_for(a)) 5497*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str) 5498*da0073e9SAndroid Build Coastguard Worker 5499*da0073e9SAndroid Build Coastguard Worker @unittest.skip("bailouts are being deprecated") 5500*da0073e9SAndroid Build Coastguard Worker def test_slice_guard_elimination(self): 5501*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5502*da0073e9SAndroid Build Coastguard Worker def my_slice(x): 5503*da0073e9SAndroid Build Coastguard Worker return x[0:16:2] + x[0:16:2] 5504*da0073e9SAndroid Build Coastguard Worker 5505*da0073e9SAndroid Build Coastguard Worker a = torch.rand(32, 4) 5506*da0073e9SAndroid Build Coastguard Worker 
5507*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 5508*da0073e9SAndroid Build Coastguard Worker my_slice(a) 5509*da0073e9SAndroid Build Coastguard Worker bailout_graph_str = str(my_slice.graph_for(a)) 5510*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str) 5511*da0073e9SAndroid Build Coastguard Worker 5512*da0073e9SAndroid Build Coastguard Worker @unittest.skip("bailouts are being deprecated") 5513*da0073e9SAndroid Build Coastguard Worker def test_unsqueeze_guard_elimination(self): 5514*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5515*da0073e9SAndroid Build Coastguard Worker def my_unsqueeze(x): 5516*da0073e9SAndroid Build Coastguard Worker return torch.unsqueeze(x, 0) + torch.unsqueeze(x, 0) 5517*da0073e9SAndroid Build Coastguard Worker 5518*da0073e9SAndroid Build Coastguard Worker a = torch.rand(32, 4) 5519*da0073e9SAndroid Build Coastguard Worker 5520*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 5521*da0073e9SAndroid Build Coastguard Worker my_unsqueeze(a) 5522*da0073e9SAndroid Build Coastguard Worker bailout_graph_str = str(my_unsqueeze.graph_for(a)) 5523*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::BailOut", 2).run(bailout_graph_str) 5524*da0073e9SAndroid Build Coastguard Worker 5525*da0073e9SAndroid Build Coastguard Worker def test_resize_input_ops(self): 5526*da0073e9SAndroid Build Coastguard Worker # resize_ and resize_as resize the input tensor. because our shape analysis 5527*da0073e9SAndroid Build Coastguard Worker # is flow invariant, we set any Tensor that can alias a resized Tensor 5528*da0073e9SAndroid Build Coastguard Worker # to the base Tensor Type, without size information. 
5529*da0073e9SAndroid Build Coastguard Worker 5530*da0073e9SAndroid Build Coastguard Worker # testing that value which is an input of a graph gets handled 5531*da0073e9SAndroid Build Coastguard Worker def out_op_graph_input(): 5532*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5533*da0073e9SAndroid Build Coastguard Worker def test(x, y, z): 5534*da0073e9SAndroid Build Coastguard Worker torch.mul(x, y, out=z) 5535*da0073e9SAndroid Build Coastguard Worker return z 5536*da0073e9SAndroid Build Coastguard Worker 5537*da0073e9SAndroid Build Coastguard Worker graph = _propagate_shapes(test.graph, 5538*da0073e9SAndroid Build Coastguard Worker (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False) 5539*da0073e9SAndroid Build Coastguard Worker self.assertTrue(next(graph.outputs()).type() == TensorType.get()) 5540*da0073e9SAndroid Build Coastguard Worker out_op_graph_input() 5541*da0073e9SAndroid Build Coastguard Worker 5542*da0073e9SAndroid Build Coastguard Worker def test_resize(): 5543*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5544*da0073e9SAndroid Build Coastguard Worker def test(x): 5545*da0073e9SAndroid Build Coastguard Worker after_resize_alias = torch.zeros([2]) 5546*da0073e9SAndroid Build Coastguard Worker for _i in range(5): 5547*da0073e9SAndroid Build Coastguard Worker b = x + 1 5548*da0073e9SAndroid Build Coastguard Worker f = [1] 5549*da0073e9SAndroid Build Coastguard Worker before_resize_alias = b.sub_(1) 5550*da0073e9SAndroid Build Coastguard Worker # for i in range(10): 5551*da0073e9SAndroid Build Coastguard Worker f.append(1) 5552*da0073e9SAndroid Build Coastguard Worker b.resize_(f) 5553*da0073e9SAndroid Build Coastguard Worker after_resize_alias = b.add_(1) 5554*da0073e9SAndroid Build Coastguard Worker return after_resize_alias 5555*da0073e9SAndroid Build Coastguard Worker 5556*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_propagation', test.graph) 5557*da0073e9SAndroid Build Coastguard Worker g 
= _propagate_shapes(test.graph, (torch.zeros(1, 1),), False) 5558*da0073e9SAndroid Build Coastguard Worker resize_node = g.findNode("aten::resize_") 5559*da0073e9SAndroid Build Coastguard Worker # first input and output of b.resize_ is b 5560*da0073e9SAndroid Build Coastguard Worker self.assertTrue(next(resize_node.inputs()).type() == TensorType.get()) 5561*da0073e9SAndroid Build Coastguard Worker self.assertTrue(next(resize_node.outputs()).type() == TensorType.get()) 5562*da0073e9SAndroid Build Coastguard Worker 5563*da0073e9SAndroid Build Coastguard Worker # correctly propagates to b alias set 5564*da0073e9SAndroid Build Coastguard Worker before_resize = g.findNode("aten::sub_") 5565*da0073e9SAndroid Build Coastguard Worker self.assertTrue(next(before_resize.outputs()).type() == TensorType.get()) 5566*da0073e9SAndroid Build Coastguard Worker 5567*da0073e9SAndroid Build Coastguard Worker after_resize = g.findNode("aten::add_") 5568*da0073e9SAndroid Build Coastguard Worker self.assertTrue(next(after_resize.outputs()).type() == TensorType.get()) 5569*da0073e9SAndroid Build Coastguard Worker 5570*da0073e9SAndroid Build Coastguard Worker test_resize() 5571*da0073e9SAndroid Build Coastguard Worker 5572*da0073e9SAndroid Build Coastguard Worker def test_resize_as(): 5573*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5574*da0073e9SAndroid Build Coastguard Worker def test(x): 5575*da0073e9SAndroid Build Coastguard Worker b = torch.zeros([2, 2]) 5576*da0073e9SAndroid Build Coastguard Worker b.resize_as_(x) 5577*da0073e9SAndroid Build Coastguard Worker return b 5578*da0073e9SAndroid Build Coastguard Worker 5579*da0073e9SAndroid Build Coastguard Worker g = test.graph 5580*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_propagation', g) 5581*da0073e9SAndroid Build Coastguard Worker g = _propagate_shapes(test.graph, (torch.zeros(1, 1),), False) 5582*da0073e9SAndroid Build Coastguard Worker 5583*da0073e9SAndroid Build Coastguard Worker # x doesn't 
alias a resized op so it shouldn't be set to base Tensor type 5584*da0073e9SAndroid Build Coastguard Worker self.assertTrue(next(g.inputs()).type() != TensorType.get()) 5585*da0073e9SAndroid Build Coastguard Worker # return is resized 5586*da0073e9SAndroid Build Coastguard Worker self.assertTrue(next(g.outputs()).type() == TensorType.get()) 5587*da0073e9SAndroid Build Coastguard Worker 5588*da0073e9SAndroid Build Coastguard Worker test_resize_as() 5589*da0073e9SAndroid Build Coastguard Worker 5590*da0073e9SAndroid Build Coastguard Worker def test_uninitialized(self): 5591*da0073e9SAndroid Build Coastguard Worker graph_str = """graph(): 5592*da0073e9SAndroid Build Coastguard Worker %1 : int = prim::Uninitialized() 5593*da0073e9SAndroid Build Coastguard Worker %2 : int = prim::Constant[value=1]() 5594*da0073e9SAndroid Build Coastguard Worker %3 : int = aten::add(%1, %2) 5595*da0073e9SAndroid Build Coastguard Worker return (%3) 5596*da0073e9SAndroid Build Coastguard Worker """ 5597*da0073e9SAndroid Build Coastguard Worker g = parse_ir(graph_str) 5598*da0073e9SAndroid Build Coastguard Worker m = self.createFunctionFromGraph(g) 5599*da0073e9SAndroid Build Coastguard Worker self.getExportImportCopy(m) 5600*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "expected int"): 5601*da0073e9SAndroid Build Coastguard Worker m() 5602*da0073e9SAndroid Build Coastguard Worker 5603*da0073e9SAndroid Build Coastguard Worker 5604*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't use requires_grad information") 5605*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.PROFILING, "Peeling is now disabled") 5606*da0073e9SAndroid Build Coastguard Worker def test_requires_grad_loop(self): 5607*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5608*da0073e9SAndroid Build Coastguard Worker def test(x, y, z): 5609*da0073e9SAndroid Build 
Coastguard Worker # type: (Tensor, Tensor, int) -> Tensor 5610*da0073e9SAndroid Build Coastguard Worker for _ in range(z): 5611*da0073e9SAndroid Build Coastguard Worker x = y 5612*da0073e9SAndroid Build Coastguard Worker return x 5613*da0073e9SAndroid Build Coastguard Worker 5614*da0073e9SAndroid Build Coastguard Worker # x requires grad, y does not 5615*da0073e9SAndroid Build Coastguard Worker # testing that requires grad analysis correctly exits, with its input 5616*da0073e9SAndroid Build Coastguard Worker # to the loop (x) requiring grad and its output to the loop not requiring grad 5617*da0073e9SAndroid Build Coastguard Worker # and the output of the node conservatively setting grad to true 5618*da0073e9SAndroid Build Coastguard Worker 5619*da0073e9SAndroid Build Coastguard Worker inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10) 5620*da0073e9SAndroid Build Coastguard Worker test(*inps, profile_and_replay=True) 5621*da0073e9SAndroid Build Coastguard Worker 5622*da0073e9SAndroid Build Coastguard Worker graph = test.graph_for(*inps) 5623*da0073e9SAndroid Build Coastguard Worker loop = graph.findNode("prim::Loop") 5624*da0073e9SAndroid Build Coastguard Worker loop_body = next(loop.blocks()) 5625*da0073e9SAndroid Build Coastguard Worker loop_inputs = list(loop_body.inputs()) 5626*da0073e9SAndroid Build Coastguard Worker loop_outputs = list(loop_body.outputs()) 5627*da0073e9SAndroid Build Coastguard Worker 5628*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.PROFILING: 5629*da0073e9SAndroid Build Coastguard Worker # TODO: simplify this test as it's very sensitive 5630*da0073e9SAndroid Build Coastguard Worker # the optimized graph will have 3 loops 5631*da0073e9SAndroid Build Coastguard Worker # the original loop is peeled 5632*da0073e9SAndroid Build Coastguard Worker # peeled loop also gets unrolled 5633*da0073e9SAndroid Build Coastguard Worker index_of_x_in_peeled_unrolled_loop = -2 5634*da0073e9SAndroid Build 
Coastguard Worker self.assertTrue(loop_inputs[index_of_x_in_peeled_unrolled_loop].requires_grad()) 5635*da0073e9SAndroid Build Coastguard Worker bailouts_in_outer_block = graph.findAllNodes("prim::BailOut", False) 5636*da0073e9SAndroid Build Coastguard Worker last_bailout_index_on_loops_output = -1 5637*da0073e9SAndroid Build Coastguard Worker self.assertFalse(bailouts_in_outer_block[last_bailout_index_on_loops_output].output().requires_grad()) 5638*da0073e9SAndroid Build Coastguard Worker else: 5639*da0073e9SAndroid Build Coastguard Worker self.assertTrue(loop_inputs[1].requires_grad()) 5640*da0073e9SAndroid Build Coastguard Worker self.assertTrue(loop.output().requires_grad()) 5641*da0073e9SAndroid Build Coastguard Worker self.assertFalse(loop_outputs[1].requires_grad()) 5642*da0073e9SAndroid Build Coastguard Worker 5643*da0073e9SAndroid Build Coastguard Worker def test_view_shape_prop(self): 5644*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 5645*da0073e9SAndroid Build Coastguard Worker def test_view_shape_prop(a): 5646*da0073e9SAndroid Build Coastguard Worker return a.view(size=[-1]) 5647*da0073e9SAndroid Build Coastguard Worker ''') 5648*da0073e9SAndroid Build Coastguard Worker inputs = [torch.zeros(10, 10)] 5649*da0073e9SAndroid Build Coastguard Worker outputs = torch.zeros(100) 5650*da0073e9SAndroid Build Coastguard Worker 5651*da0073e9SAndroid Build Coastguard Worker real_outs = cu.test_view_shape_prop(*inputs) 5652*da0073e9SAndroid Build Coastguard Worker self.assertEqual(real_outs, outputs) 5653*da0073e9SAndroid Build Coastguard Worker 5654*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 5655*da0073e9SAndroid Build Coastguard Worker def test_view_listconstruct_shape_prop(self): 5656*da0073e9SAndroid Build Coastguard Worker def fn(x): 5657*da0073e9SAndroid Build Coastguard Worker B = x.size(0) 5658*da0073e9SAndroid Build Coastguard Worker C = x.size(1) 5659*da0073e9SAndroid 
Build Coastguard Worker T = x.size(2) 5660*da0073e9SAndroid Build Coastguard Worker return x.view(T, B, C) 5661*da0073e9SAndroid Build Coastguard Worker 5662*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 1, 5, requires_grad=True) 5663*da0073e9SAndroid Build Coastguard Worker fn = torch.jit.script(fn) 5664*da0073e9SAndroid Build Coastguard Worker graph = _propagate_shapes(fn.graph, (x,), False) 5665*da0073e9SAndroid Build Coastguard Worker self.assertTrue(next(graph.outputs()).type().scalarType() == 'Float') 5666*da0073e9SAndroid Build Coastguard Worker 5667*da0073e9SAndroid Build Coastguard Worker def test_shape_prop_promotion(self): 5668*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5669*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 5670*da0073e9SAndroid Build Coastguard Worker return x + y 5671*da0073e9SAndroid Build Coastguard Worker 5672*da0073e9SAndroid Build Coastguard Worker x, y = torch.rand(3, 4, dtype=torch.float), torch.rand(3, 4, dtype=torch.double) 5673*da0073e9SAndroid Build Coastguard Worker graph = _propagate_shapes(fn.graph, (x, y), False) 5674*da0073e9SAndroid Build Coastguard Worker FileCheck().check('Double(*, *, device=cpu) = aten::add').run(graph) 5675*da0073e9SAndroid Build Coastguard Worker 5676*da0073e9SAndroid Build Coastguard Worker def test_shape_prop_promote_scalar_arg(self): 5677*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 5678*da0073e9SAndroid Build Coastguard Worker def fn(x): 5679*da0073e9SAndroid Build Coastguard Worker return math.pi + x 5680*da0073e9SAndroid Build Coastguard Worker 5681*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(3, 4, dtype=torch.long) 5682*da0073e9SAndroid Build Coastguard Worker graph = _propagate_shapes(fn.graph, (x,), False) 5683*da0073e9SAndroid Build Coastguard Worker default = torch.get_default_dtype() 5684*da0073e9SAndroid Build Coastguard Worker if default == torch.float: 5685*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check('Float(*, *, requires_grad=0, device=cpu) = aten::add').run(graph) 5686*da0073e9SAndroid Build Coastguard Worker else: 5687*da0073e9SAndroid Build Coastguard Worker FileCheck().check('Double(*, *, requires_grad=0, device=cpu) = aten::add').run(graph) 5688*da0073e9SAndroid Build Coastguard Worker 5689*da0073e9SAndroid Build Coastguard Worker def test_integral_shape_inference(self): 5690*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 5691*da0073e9SAndroid Build Coastguard Worker def test_integral_shape_inference(a): 5692*da0073e9SAndroid Build Coastguard Worker return a * a 5693*da0073e9SAndroid Build Coastguard Worker ''') 5694*da0073e9SAndroid Build Coastguard Worker inputs = [torch.ones(10, 10, dtype=torch.long)] 5695*da0073e9SAndroid Build Coastguard Worker outputs = torch.ones(10, 10, dtype=torch.long) 5696*da0073e9SAndroid Build Coastguard Worker 5697*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs) 5698*da0073e9SAndroid Build Coastguard Worker 5699*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser') 5700*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") 5701*da0073e9SAndroid Build Coastguard Worker @enable_cpu_fuser 5702*da0073e9SAndroid Build Coastguard Worker def test_batchnorm_fuser_cpu(self): 5703*da0073e9SAndroid Build Coastguard Worker code = ''' 5704*da0073e9SAndroid Build Coastguard Worker graph(%3 : Tensor, 5705*da0073e9SAndroid Build Coastguard Worker %7 : Tensor, 5706*da0073e9SAndroid Build Coastguard Worker %12 : Float(*, *), 5707*da0073e9SAndroid Build Coastguard Worker %13 : Tensor, 5708*da0073e9SAndroid Build Coastguard Worker %25 : Tensor): 5709*da0073e9SAndroid Build Coastguard Worker %23 : int = prim::Constant[value=1]() 5710*da0073e9SAndroid Build Coastguard Worker %22 : float = prim::Constant[value=1e-05]() 5711*da0073e9SAndroid 
Build Coastguard Worker %26 : Tensor = aten::sqrt(%25) 5712*da0073e9SAndroid Build Coastguard Worker %24 : Tensor = aten::add(%26, %22, %23) 5713*da0073e9SAndroid Build Coastguard Worker %20 : Tensor = aten::reciprocal(%24) 5714*da0073e9SAndroid Build Coastguard Worker %norm_invstd : Tensor = aten::mul(%20, %23) 5715*da0073e9SAndroid Build Coastguard Worker %15 : Tensor = aten::sub(%12, %13, %23) 5716*da0073e9SAndroid Build Coastguard Worker %11 : Tensor = aten::mul(%15, %norm_invstd) 5717*da0073e9SAndroid Build Coastguard Worker %8 : Tensor = aten::mul(%11, %7) 5718*da0073e9SAndroid Build Coastguard Worker %5 : Tensor = aten::add(%8, %3, %23) 5719*da0073e9SAndroid Build Coastguard Worker %1 : Float(*, *) = aten::relu(%5) 5720*da0073e9SAndroid Build Coastguard Worker return (%1) 5721*da0073e9SAndroid Build Coastguard Worker ''' 5722*da0073e9SAndroid Build Coastguard Worker 5723*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(code) 5724*da0073e9SAndroid Build Coastguard Worker inputs = 5 * [torch.rand(26, 2048, dtype=torch.float)] 5725*da0073e9SAndroid Build Coastguard Worker code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs) 5726*da0073e9SAndroid Build Coastguard Worker FileCheck().check('sqrtf').run(code) 5727*da0073e9SAndroid Build Coastguard Worker 5728*da0073e9SAndroid Build Coastguard Worker @slowTest 5729*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser') 5730*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") 5731*da0073e9SAndroid Build Coastguard Worker @enable_cpu_fuser 5732*da0073e9SAndroid Build Coastguard Worker def test_fuser_double_float_codegen(self): 5733*da0073e9SAndroid Build Coastguard Worker fns = ['log', 'log10', 'log1p', 'log2', 'lgamma', 'exp', 'expm1', 'erf', 5734*da0073e9SAndroid Build Coastguard Worker 'erfc', 'cos', 'acos', 'cosh', 'sin', 'asin', 'sinh', 'tan', 5735*da0073e9SAndroid Build Coastguard Worker 'atan', 
'tanh', 'sqrt', 'ceil', 'floor', 'round', 'trunc', 5736*da0073e9SAndroid Build Coastguard Worker 'frac'] 5737*da0073e9SAndroid Build Coastguard Worker 5738*da0073e9SAndroid Build Coastguard Worker def lookup_c_equivalent_fn(aten_fn): 5739*da0073e9SAndroid Build Coastguard Worker return aten_fn 5740*da0073e9SAndroid Build Coastguard Worker 5741*da0073e9SAndroid Build Coastguard Worker def test_dispatch(op, expects, dtype, binary=False): 5742*da0073e9SAndroid Build Coastguard Worker if dtype == torch.double: 5743*da0073e9SAndroid Build Coastguard Worker dtype_str = 'Double' 5744*da0073e9SAndroid Build Coastguard Worker elif dtype == torch.float: 5745*da0073e9SAndroid Build Coastguard Worker dtype_str = 'Float' 5746*da0073e9SAndroid Build Coastguard Worker else: 5747*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('Unknown dtype') 5748*da0073e9SAndroid Build Coastguard Worker 5749*da0073e9SAndroid Build Coastguard Worker if binary: 5750*da0073e9SAndroid Build Coastguard Worker code = f''' 5751*da0073e9SAndroid Build Coastguard Worker graph(%3 : Tensor, %4 : Tensor): 5752*da0073e9SAndroid Build Coastguard Worker %2 : {dtype_str}(*, *) = aten::{op}(%3, %4) 5753*da0073e9SAndroid Build Coastguard Worker %1 : {dtype_str}(*, *) = aten::relu(%2) 5754*da0073e9SAndroid Build Coastguard Worker return (%1) 5755*da0073e9SAndroid Build Coastguard Worker ''' 5756*da0073e9SAndroid Build Coastguard Worker else: 5757*da0073e9SAndroid Build Coastguard Worker code = f''' 5758*da0073e9SAndroid Build Coastguard Worker graph(%3 : Tensor): 5759*da0073e9SAndroid Build Coastguard Worker %2 : {dtype_str}(*, *) = aten::{op}(%3) 5760*da0073e9SAndroid Build Coastguard Worker %1 : {dtype_str}(*, *) = aten::relu(%2) 5761*da0073e9SAndroid Build Coastguard Worker return (%1) 5762*da0073e9SAndroid Build Coastguard Worker ''' 5763*da0073e9SAndroid Build Coastguard Worker 5764*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(code) 5765*da0073e9SAndroid Build Coastguard Worker inputs 
= (2 if binary else 1) * [torch.rand(26, 2048, dtype=dtype)] 5766*da0073e9SAndroid Build Coastguard Worker code = torch._C._jit_fuser_get_fused_kernel_code(graph, inputs) 5767*da0073e9SAndroid Build Coastguard Worker FileCheck().check(expects).run(code) 5768*da0073e9SAndroid Build Coastguard Worker 5769*da0073e9SAndroid Build Coastguard Worker for fn in fns: 5770*da0073e9SAndroid Build Coastguard Worker test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double) 5771*da0073e9SAndroid Build Coastguard Worker test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float) 5772*da0073e9SAndroid Build Coastguard Worker 5773*da0073e9SAndroid Build Coastguard Worker # 'min', 'max' were previously tested but are now replaced with ternary expressions 5774*da0073e9SAndroid Build Coastguard Worker # instead of fmin() and fmax() 5775*da0073e9SAndroid Build Coastguard Worker binary_fns = ['pow'] 5776*da0073e9SAndroid Build Coastguard Worker for fn in binary_fns: 5777*da0073e9SAndroid Build Coastguard Worker test_dispatch(fn, lookup_c_equivalent_fn(fn) + '(', torch.double, binary=True) 5778*da0073e9SAndroid Build Coastguard Worker test_dispatch(fn, lookup_c_equivalent_fn(fn) + 'f(', torch.float, binary=True) 5779*da0073e9SAndroid Build Coastguard Worker 5780*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(RUN_CUDA, 'This tests the CPU fuser') 5781*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle") 5782*da0073e9SAndroid Build Coastguard Worker @enable_cpu_fuser 5783*da0073e9SAndroid Build Coastguard Worker def test_fuser_double_literal_precision(self): 5784*da0073e9SAndroid Build Coastguard Worker code = ''' 5785*da0073e9SAndroid Build Coastguard Worker graph(%2 : Float(*, *)): 5786*da0073e9SAndroid Build Coastguard Worker %4 : int = prim::Constant[value=1]() 5787*da0073e9SAndroid Build Coastguard Worker %3 : float = prim::Constant[value=1.282549830161864]() 5788*da0073e9SAndroid Build Coastguard Worker 
%5 : Float(*, *) = aten::add(%2, %3, %4) 5789*da0073e9SAndroid Build Coastguard Worker %1 : Float(*, *) = aten::relu(%5) 5790*da0073e9SAndroid Build Coastguard Worker return (%1) 5791*da0073e9SAndroid Build Coastguard Worker ''' 5792*da0073e9SAndroid Build Coastguard Worker 5793*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(code) 5794*da0073e9SAndroid Build Coastguard Worker code = torch._C._jit_fuser_get_fused_kernel_code(graph, [torch.rand(3, 4)]) 5795*da0073e9SAndroid Build Coastguard Worker FileCheck().check('1.282549830161864').run(code) 5796*da0073e9SAndroid Build Coastguard Worker 5797*da0073e9SAndroid Build Coastguard Worker def test_fuser_multiple_blocks(self): 5798*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 5799*da0073e9SAndroid Build Coastguard Worker def test_fuser_multiple_blocks(this, that, theother, meme): 5800*da0073e9SAndroid Build Coastguard Worker i = 0 5801*da0073e9SAndroid Build Coastguard Worker while i < 20: 5802*da0073e9SAndroid Build Coastguard Worker this = torch.cat([this, meme], dim=0) 5803*da0073e9SAndroid Build Coastguard Worker that = torch.cat([that, meme], dim=0) 5804*da0073e9SAndroid Build Coastguard Worker theother = torch.cat([theother, meme], dim=0) 5805*da0073e9SAndroid Build Coastguard Worker i = i + 1 5806*da0073e9SAndroid Build Coastguard Worker return this, that, theother 5807*da0073e9SAndroid Build Coastguard Worker ''') 5808*da0073e9SAndroid Build Coastguard Worker 5809*da0073e9SAndroid Build Coastguard Worker inputs = [torch.ones(0, 10, 10)] * 3 5810*da0073e9SAndroid Build Coastguard Worker inputs += [torch.ones(1, 10, 10)] 5811*da0073e9SAndroid Build Coastguard Worker outputs = [torch.ones(20, 10, 10)] * 3 5812*da0073e9SAndroid Build Coastguard Worker 5813*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs) 5814*da0073e9SAndroid Build Coastguard Worker 5815*da0073e9SAndroid Build Coastguard Worker 
@unittest.skip("RuntimeError: VariableType::ID() not implemented")
def test_cast(self):
    """Script ``int(x)`` on a float tensor and compare with an int tensor.

    Skipped: the decorator records the error this hit when last attempted.
    """
    script = '''
        def to_int(x):
            return int(x)
        '''
    x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True)
    out = Variable(torch.IntTensor([1, 2]), requires_grad=True)
    self.checkScript(script, [x], optimize=True, outputs=[out], func='to_int')

def test_str_cast(self):
    """Scripted ``str`` of an ``(int, int)`` tuple renders as ``"(1, 1)"``."""
    @torch.jit.script
    def to_str(x):
        # type: (int) -> str
        return str((x, x))

    self.assertEqual("(1, 1)", to_str(1))

def test_int_cast(self):
    """Scripted ``int(str)`` parses decimal literals and rejects 0x/0b prefixes."""
    @torch.jit.script
    def to_int(x):
        # type: (str) -> int
        return int(x)

    self.assertEqual(5, to_int('5'))
    self.assertEqual(-5, to_int('-5'))
    # Edges of the signed 32-bit range.
    self.assertEqual(2147483647, to_int('2147483647'))
    self.assertEqual(-2147483648, to_int('-2147483648'))

    # The parentheses are escaped on purpose: unescaped, ``()`` is an empty
    # regex group and the pattern would accept any message that merely
    # contains "invalid literal for int".
    with self.assertRaisesRegex(RuntimeError, r"invalid literal for int\(\)"):
        to_int('0x20')

    with self.assertRaisesRegex(RuntimeError, r"invalid literal for int\(\)"):
        to_int('0b0001')

def test_python_frontend(self):
    """The frontend AST of ``fn`` matches the output recorded for ``assertExpected``."""
    # ``fn`` is only parsed, never run; its exact statements are the test input.
    def fn(x, y, z):
        q = None
        q = x + y - z.sigmoid()
        print(q)
        w = -z
        if not x and not y and z:
            m = x if not z else y
        while x < y > z:
            q = x
        assert 1 == 1, "hello"
        return x

    ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
    self.assertExpected(str(ast))

def test_python_frontend_source_range(self):
    """``str(ast.range())`` shows the source of ``fn`` with a ``<--- HERE`` marker."""
    def fn():
        raise Exception("hello")  # noqa: TRY002
    ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
    (
        FileCheck()
        .check("SourceRange at:")
        .check("def fn():")
        .check("~~~~~~~~~")
        .check('raise Exception("hello")')
        .check('~~~~~~~~~~~~~~~~~ <--- HERE')
        .run(str(ast.range()))
    )

def test_python_frontend_py3(self):
    """The frontend AST of a function that only raises matches the recorded output."""
    def fn():
        raise Exception("hello")  # noqa: TRY002
    ast = torch.jit.frontend.get_jit_def(fn, fn.__name__)
    self.assertExpected(str(ast))

def _make_scalar_vars(self, arr, dtype):
    """Return one ``torch.tensor(val, dtype=dtype)`` per value in ``arr``.

    The callers in this file pass Python scalars, so each result is 0-dim.
    """
    return [torch.tensor(val, dtype=dtype) for val in arr]


def test_string_print(self):
    """``print`` with implicitly concatenated string literals matches eager output."""
    def func(a):
        print(a, "a" 'b' '''c''' """d""", 2, 1.5)
        return a

    inputs = self._make_scalar_vars([1], torch.int64)
    self.checkScript(func, inputs, capture_output=True)

def test_while(self):
    """A ``while`` loop over tensor-valued conditions matches eager."""
    def func(a, b, max):
        while bool(a < max):
            a = a + 1
            b = b + 1
        c = a + b
        return c

    inputs = self._make_scalar_vars([1, 1, 10], torch.int64)
    self.checkScript(func, inputs, optimize=True)

def test_fibb(self):
    """Nested ``while`` loops with loop-carried and untouched variables match eager."""
    def func(lim):
        first = 1
        second = 1
        i = 1
        somenum = 5
        dontmutateme = 3
        third = 0
        while bool(i < lim):
            third = first + second
            first = second
            second = third
            j = 0
            while j < 10:
                somenum = somenum * 2
                j = j + 1
            i = i + j
            i = i + dontmutateme

        st = second + third
        fs = first + second
        return third, st, fs

    inputs = self._make_scalar_vars([10], torch.int64)
    self.checkScript(func, inputs, optimize=True)
def test_fibb_totally_better(self):
    """A ``for ... in range`` Fibonacci loop over ints matches eager."""
    def fib(x):
        # type: (int) -> int
        prev = 1
        v = 1
        for i in range(0, x):
            save = v
            v = v + prev
            prev = save
        return v

    self.checkScript(fib, (10,))

def test_if(self):
    """An ``if``/``else`` that reassigns different variables per branch matches eager."""
    def func(a, b):
        # type: (int, int) -> int
        d = 3
        if bool(a > 10):
            a = 3 + d
        else:
            b = 3 + d
            d = 4
        c = a + b
        return c

    inputs = self._make_scalar_vars([1, -1], torch.int64)
    self.checkScript(func, inputs, optimize=True)

def test_if_for_in_range(self):
    """The same ``if``/``else`` nested in a ``for ... in range`` loop matches eager."""
    def func(a, b):
        # type: (int, int) -> int
        d = 3
        for _ in range(20):
            if bool(a > 10):
                a = 3 + d
            else:
                b = 3 + d
                d = 4
            c = a + b
        return d
    inputs = self._make_scalar_vars([1, -1], torch.int64)
    self.checkScript(func, inputs, optimize=True)

def test_if_noelse(self):
    """An ``if`` with no ``else`` branch matches eager."""
    def func(a, b):
        if bool(a > 10):
            a = 3 + b
        c = a + b
        return c

    inputs = self._make_scalar_vars([-1, 1], torch.int64)
    self.checkScript(func, inputs, optimize=True)

def test_if_is_none_dispatch(self):
    """``is``/``is not None`` tests decidable from operand types emit only one branch.

    Checked by counting ``int`` constants (or ``if`` occurrences) in the
    scripted graphs.
    """

    @torch.jit.script
    def test_lhs_none_rhs_none():
        # LHS, RHS both alwaysNone, dispatch always_none_branch
        # only emit one prim::Constant
        if None is None:
            return 1
        elif None is not None:
            return 2
        else:
            return 3

    self.assertEqual(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant'), 1)

    @torch.jit.script
    def test_lhs_opt_rhs_none(lhs=None):
        # type: (Optional[Tensor]) -> int
        # LHS maybeNone: emit normal if stmt that contains 3 constants
        if lhs is not None:
            return 2
        elif lhs is None:
            return 1
        else:
            return 3

    self.assertEqual(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant'), 3)

    @torch.jit.script
    def test_lhs_none_rhs_opt(rhs=None):
        # type: (Optional[Tensor]) -> int
        # RHS maybeNone, emit normal if stmt that contains 3 constants
        if None is rhs:
            return 1
        elif None is not rhs:
            return 2
        else:
            return 3

    # NOTE(review): this re-checks ``test_lhs_opt_rhs_none`` instead of
    # ``test_lhs_none_rhs_opt`` defined just above, so the latter's graph is
    # never inspected. Confirm the real count before retargeting: the
    # ``elif`` runs after ``None is rhs`` was false, so ``rhs`` may already
    # be refined and the count may not be 3 as the comment above says.
    self.assertEqual(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant'), 3)

    @torch.jit.script
    def test_lhs_never_rhs_none(lhs):
        # LHS neverNone, RHS alwaysNone dispatch never_none_branch
        # only emit one prim::Constant
        if lhs is None:
            return 1
        elif lhs is not None:
            return 2
        else:
            return 3

    self.assertEqual(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant'), 1)

    @torch.jit.script
    def test_lhs_none_rhs_never(rhs):
        # LHS alwaysNone, RHS neverNone dispatch never_none_branch
        # only emit one prim::Constant
        if None is rhs:
            return 1
        elif None is not rhs:
            return 2
        else:
            return 3

    self.assertEqual(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant'), 1)

    @torch.jit.script
    def test_bool_arith_and(lhs):
        if lhs is None and lhs is not None:
            return 1
        else:
            return 2
    self.assertEqual(test_bool_arith_and(torch.zeros(3)), 2)
    self.assertEqual(str(test_bool_arith_and.graph).count('if'), 0)

    @torch.jit.script
    def test_bool_arith_or(lhs):
        if lhs is None or lhs is not None:
            return 1
        else:
            return 2
    self.assertEqual(test_bool_arith_or(torch.zeros(3)), 1)
    self.assertEqual(str(test_bool_arith_or.graph).count('if'), 0)


    @torch.jit.script
    def test_bool_arith_not(lhs):
        if lhs is not None:
            return 1
        else:
            return 2
    self.assertEqual(test_bool_arith_not(torch.zeros(3)), 1)
    self.assertEqual(str(test_bool_arith_not.graph).count('if'), 0)

def test_conditional_casting(self):
    """Truthiness of tensors, ints and floats in conditions matches eager.

    Multi-element tensors raise at run time; tuples used as conditions are
    rejected when scripting.
    """
    def test_bool_cast_tensor(x):
        if x:
            return 1
        else:
            return 0

    # Cover both 0-dim and single-element 1-dim tensors.
    for make_one_dim in [True, False]:
        for inp_val in [0.1, 0.0, -0.0, -0.1, -1, 0, 1]:
            val = [inp_val] if make_one_dim else inp_val
            self.checkScript(test_bool_cast_tensor, (torch.tensor(val),))

    self.checkScriptRaisesRegex(test_bool_cast_tensor, (torch.tensor([1, 1]),), Exception,
                                "Boolean value of Tensor with more than one value")

    def test_not_cast(x):
        if not x:
            return 1
        else:
            return 0

    self.checkScript(test_not_cast, (torch.tensor(1),))
    self.checkScript(test_not_cast, (torch.tensor(0),))

    with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[Tensor, Tensor\]"):
        @torch.jit.script
        def test_mult(x, y):
            return not (x, y)

    def test_cast_int(x):
        # type: (int) -> int
        if x:
            return 1
        else:
            return 0
    self.checkScript(test_cast_int, (1,))
    self.checkScript(test_cast_int, (0,))
    self.checkScript(test_cast_int, (-1,))

    def test_cast_float(x):
        # type: (float) -> int
        if x:
            return 1
        else:
            return 0
    self.checkScript(test_cast_float, (1.,))
    self.checkScript(test_cast_float, (0.,))
    self.checkScript(test_cast_float, (-1.,))

    with self.assertRaisesRegex(RuntimeError, r"Could not cast value of type Tuple\[int, int\] to bool"):

        @torch.jit.script
        def test_bad_conditional(x):
            if (1, 2):  # noqa: F634
                return
            else:
                return 0

def test_while_nonexistent_value(self):
    """Reading an undefined name inside a ``while`` body fails to compile."""
    with self.assertRaisesRegex(RuntimeError, "undefined value x"):
        torch.jit.CompilationUnit('''
        def test_while(a, b):
            while bool(a < 10):
                a = a + x
                b = b + 1
            return a + b
        ''')

def test_while_nonexistent_cond_value(self):
    """Reading an undefined name in a ``while`` condition fails to compile.

    The rest of this test scripts functions that rely on ``Optional[int]``
    refinement through ``is (not) None`` combined with ``and``/``or``/``not``.
    It also checks that refinement stored in a named boolean variable, or
    used under the wrong ``and``/``or`` polarity, is rejected.
    """
    with self.assertRaisesRegex(RuntimeError, "undefined value x"):
        torch.jit.CompilationUnit('''
        def test_while(a, b):
            while a < x:
                a = a + 1
                b = b + 1
            return a + b
        ''')

    @torch.jit.script
    def test_ternary(x):
        # type: (Optional[int]) -> int
        x = x if x is not None else 2
        return x

    @torch.jit.script
    def test_not_none(x):
        # type: (Optional[int]) -> None
        if x is not None:
            print(x + 1)

    @torch.jit.script
    def test_and(x, y):
        # type: (Optional[int], Optional[int]) -> None
        if x is not None and y is not None:
            print(x + y)

    @torch.jit.script
    def test_not(x, y):
        # type: (Optional[int], Optional[int]) -> None
        if not (x is not None and y is not None):
            pass
        else:
            print(x + y)

    @torch.jit.script
    def test_bool_expression(x):
        # type: (Optional[int]) -> None
        if x is not None and x < 2:
            print(x + 1)

    @torch.jit.script
    def test_nested_bool_expression(x, y):
        # type: (Optional[int], Optional[int]) -> int
        if x is not None and x < 2 and y is not None:
            x = x + y
        else:
            x = 5
        return x + 2

    @torch.jit.script
    def test_or(x, y):
        # type: (Optional[int], Optional[int]) -> None
        if y is None or x is None:
            pass
        else:
            print(x + y)

    # backwards compatibility
    @torch.jit.script
    def test_manual_unwrap_opt(x):
        # type: (Optional[int]) -> int
        if x is None:
            x = 1
        else:
            x = torch.jit._unwrap_optional(x)
        return x  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
        @torch.jit.script
        def or_error(x, y):
            # type: (Optional[int], Optional[int]) -> None
            if x is None or y is None:
                print(x + y)  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
        @torch.jit.script
        def and_error(x, y):
            # type: (Optional[int], Optional[int]) -> None
            if x is None and y is None:
                pass
            else:
                print(x + y)  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
        @torch.jit.script
        def named_var(x):
            # type: (Optional[int]) -> None
            x_none = x is not None
            if x_none:
                print(x + 1)  # noqa: T484

    with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
        @torch.jit.script
        def named_var_and(x, y):
            # type: (Optional[int], Optional[int]) -> None
            x_none = x is not None
            if y is not None and x_none:
                print(x + y)  # noqa: T484

def test_assertion_optional_refinement(self):
    """``assert x is not None and y is not None`` refines both arguments after it."""
    @torch.jit.script
    def test(x, y):
        # type: (Optional[int], Optional[int]) -> int
        assert x is not None and y is not None
        return x + y

    self.assertEqual(test(2, 2), 4)
    # Any exception is acceptable here; only the failing assert matters.
    with self.assertRaises(Exception):
        test(1, None)

@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
def test_optional_tensor(self):
    """The legacy executor specializes graphs on whether an ``Optional[Tensor]`` is ``None``."""
    @torch.jit.script
    def fn(x, y):
        # type: (Optional[Tensor], int) -> int
        if x is None:
            return y
        else:
            return 0

    res = fn(None, 1)
    self.assertEqual(res, 1)
    g = torch.jit.last_executed_optimized_graph()
    first_input = next(g.inputs())
    # check if input is disconnected
    self.assertEqual(first_input.type().kind(), 'OptionalType')
    self.assertEqual(first_input.uses(), [])
    t = torch.ones(1)
    res = fn(t, 1)
    self.assertEqual(res, 0)
    g = torch.jit.last_executed_optimized_graph()
    self.assertEqual(next(g.inputs()).type().kind(), 'TensorType')

    @torch.jit.script
    def fn(x, y, b):
        # type: (Optional[Tensor], Tensor, bool) -> Tensor
        if b:
            res = y
        else:
            res = torch.jit._unwrap_optional(x)
        return res

    t2 = torch.zeros(1)
    res = fn(t, t2, True)
    self.assertEqual(res, t2)
    with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
        res = fn(None, t2, False)
    res = fn(None, t2, True)
    g = torch.jit.last_executed_optimized_graph()
    self.assertIn(next(g.outputs()).type().str(), ("Tensor", "Tensor(requires_grad=1)"))

@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the current version of Profiler doesn't profile/specialize Optionals")
def test_optional_list(self):
    """The legacy executor specializes graphs on whether an ``Optional[List[int]]`` is ``None``."""
    @torch.jit.script
    def fn(x, y):
        # type: (Optional[List[int]], int) -> int
        if x is None:
            return y
        else:
            res = 0
            for d in x:
                res += d
            return res

    res = fn(None, 1)
    self.assertEqual(res, 1)
    g = torch.jit.last_executed_optimized_graph()
    first_input = next(g.inputs())
    # check if input is disconnected
    self.assertEqual(first_input.type().kind(), 'OptionalType')
    self.assertEqual(first_input.uses(), [])
    l = [2, 3]
    res = fn(l, 1)
    self.assertEqual(res, 5)
    g = torch.jit.last_executed_optimized_graph()
    self.assertEqual(next(g.inputs()).type().kind(), 'ListType')

    @torch.jit.script
    def fn(x, y, b):
        # type: (Optional[List[int]], List[int], bool) -> List[int]
        if b:
            l = torch.jit._unwrap_optional(x)
        else:
            l = y
        return l

    l2 = [0, 1]
    res = fn(l, l2, True)
    self.assertEqual(res, l)
    with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
        res = fn(None, l2, True)
    res = fn(None, l2, False)
    g = torch.jit.last_executed_optimized_graph()
    self.assertEqual(next(g.outputs()).type().str(), "int[]")

def test_alias_covariant_type_containers(self):
    """Scripting accepts tuples holding ``None`` in one branch and a list in another.

    This only checks that both functions compile.
    """
    @torch.jit.script
    def foo(x):
        # type: (bool)
        if x:
            a = (None,)
        else:
            a = ([],)
        return a

    @torch.jit.script
    def foo2(x, li):
        # type: (bool, Tuple[Optional[List[Tensor]]])
        if x:
            li = (None,)
        return li

def test_while_write_outer_then_read(self):
    def func(a, b):
        while bool(a < 10):
            a = a + 1
            b = a + 1
        return a + b
Coastguard Worker 6369*da0073e9SAndroid Build Coastguard Worker inputs = self._make_scalar_vars([42, 1337], torch.int64) 6370*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, inputs, optimize=True) 6371*da0073e9SAndroid Build Coastguard Worker 6372*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 6373*da0073e9SAndroid Build Coastguard Worker def test_while_nest_if(self): 6374*da0073e9SAndroid Build Coastguard Worker def func(a, b): 6375*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> int 6376*da0073e9SAndroid Build Coastguard Worker c = 0 6377*da0073e9SAndroid Build Coastguard Worker while a < 10: 6378*da0073e9SAndroid Build Coastguard Worker a = a + 1 6379*da0073e9SAndroid Build Coastguard Worker b = b + 1 6380*da0073e9SAndroid Build Coastguard Worker if a > b: 6381*da0073e9SAndroid Build Coastguard Worker c = -a 6382*da0073e9SAndroid Build Coastguard Worker else: 6383*da0073e9SAndroid Build Coastguard Worker c = -b 6384*da0073e9SAndroid Build Coastguard Worker return c + 1 6385*da0073e9SAndroid Build Coastguard Worker 6386*da0073e9SAndroid Build Coastguard Worker inputs = self._make_scalar_vars([-1234, 4321], torch.int64) 6387*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, inputs, optimize=True) 6388*da0073e9SAndroid Build Coastguard Worker 6389*da0073e9SAndroid Build Coastguard Worker def test_divmod(self): 6390*da0073e9SAndroid Build Coastguard Worker def func_int(a, b): 6391*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> Tuple[int, int] 6392*da0073e9SAndroid Build Coastguard Worker return divmod(a, b) 6393*da0073e9SAndroid Build Coastguard Worker 6394*da0073e9SAndroid Build Coastguard Worker def func_float(a, b): 6395*da0073e9SAndroid Build Coastguard Worker # type: (float, float) -> Tuple[float, float] 6396*da0073e9SAndroid Build Coastguard Worker return divmod(a, b) 6397*da0073e9SAndroid Build Coastguard Worker 6398*da0073e9SAndroid Build 
Coastguard Worker def func_int_float(a, b): 6399*da0073e9SAndroid Build Coastguard Worker # type: (int, float) -> Tuple[float, float] 6400*da0073e9SAndroid Build Coastguard Worker return divmod(a, b) 6401*da0073e9SAndroid Build Coastguard Worker 6402*da0073e9SAndroid Build Coastguard Worker def func_float_int(a, b): 6403*da0073e9SAndroid Build Coastguard Worker # type: (float, int) -> Tuple[float, float] 6404*da0073e9SAndroid Build Coastguard Worker return divmod(a, b) 6405*da0073e9SAndroid Build Coastguard Worker 6406*da0073e9SAndroid Build Coastguard Worker def divmod_test_iterator(func, num, den): 6407*da0073e9SAndroid Build Coastguard Worker for i in num: 6408*da0073e9SAndroid Build Coastguard Worker for j in den: 6409*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (i, j), frames_up=2) 6410*da0073e9SAndroid Build Coastguard Worker 6411*da0073e9SAndroid Build Coastguard Worker num_int = [1024, -1024] 6412*da0073e9SAndroid Build Coastguard Worker den_int = [10, -10] 6413*da0073e9SAndroid Build Coastguard Worker num_float = [5.3, -5.3] 6414*da0073e9SAndroid Build Coastguard Worker den_float = [2.0, -2.0] 6415*da0073e9SAndroid Build Coastguard Worker divmod_test_iterator(func_int, num_int, den_int) 6416*da0073e9SAndroid Build Coastguard Worker divmod_test_iterator(func_float, num_float, den_float) 6417*da0073e9SAndroid Build Coastguard Worker divmod_test_iterator(func_int_float, num_int, den_float) 6418*da0073e9SAndroid Build Coastguard Worker divmod_test_iterator(func_float_int, num_float, den_int) 6419*da0073e9SAndroid Build Coastguard Worker 6420*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: integer division or modulo by zero"): 6421*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int))) 6422*da0073e9SAndroid Build Coastguard Worker cu.func_int(1024, 0) 6423*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"): 6424*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float))) 6425*da0073e9SAndroid Build Coastguard Worker cu.func_float(5.3, 0.0) 6426*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"): 6427*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_int_float))) 6428*da0073e9SAndroid Build Coastguard Worker cu.func_int_float(1024, 0.0) 6429*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "ZeroDivisionError: float divmod()"): 6430*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(dedent(inspect.getsource(func_float_int))) 6431*da0073e9SAndroid Build Coastguard Worker cu.func_float_int(5.3, 0) 6432*da0073e9SAndroid Build Coastguard Worker 6433*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a TorchDynamo suitable test") 6434*da0073e9SAndroid Build Coastguard Worker def test_math_ops(self): 6435*da0073e9SAndroid Build Coastguard Worker def checkMathWrap(func_name, num_args=1, is_float=True, **args): 6436*da0073e9SAndroid Build Coastguard Worker if is_float: 6437*da0073e9SAndroid Build Coastguard Worker checkMath(func_name, num_args, True, **args) 6438*da0073e9SAndroid Build Coastguard Worker checkMath(func_name, num_args, False, **args) 6439*da0073e9SAndroid Build Coastguard Worker else: 6440*da0073e9SAndroid Build Coastguard Worker checkMath(func_name, num_args, is_float, **args) 6441*da0073e9SAndroid Build Coastguard Worker 6442*da0073e9SAndroid Build Coastguard Worker inf = float("inf") 6443*da0073e9SAndroid Build Coastguard Worker NaN = float("nan") 6444*da0073e9SAndroid Build Coastguard Worker mx_int = 2**31 - 1 6445*da0073e9SAndroid Build Coastguard Worker mn_int = -2**31 6446*da0073e9SAndroid Build Coastguard Worker float_vals = ([inf, NaN, 0.0, 
1.0, 2.2, -1.0, -0.0, -2.2, -inf, 1, 0, 2] + 6447*da0073e9SAndroid Build Coastguard Worker [10.0 ** i for i in range(5)] + [-(10.0 ** i) for i in range(5)]) 6448*da0073e9SAndroid Build Coastguard Worker int_vals = list(range(-5, 5, 1)) + [mx_int + 5, mx_int * 2, mn_int - 5, mn_int * 2] 6449*da0073e9SAndroid Build Coastguard Worker 6450*da0073e9SAndroid Build Coastguard Worker def checkMath(func_name, num_args, is_float=True, ret_type="float", debug=False, vals=None, args_type=None): 6451*da0073e9SAndroid Build Coastguard Worker funcs_template = dedent(''' 6452*da0073e9SAndroid Build Coastguard Worker def func(a, b): 6453*da0073e9SAndroid Build Coastguard Worker # type: {args_type} -> {ret_type} 6454*da0073e9SAndroid Build Coastguard Worker return math.{func}({args}) 6455*da0073e9SAndroid Build Coastguard Worker ''') 6456*da0073e9SAndroid Build Coastguard Worker if num_args == 1: 6457*da0073e9SAndroid Build Coastguard Worker args = "a" 6458*da0073e9SAndroid Build Coastguard Worker elif num_args == 2: 6459*da0073e9SAndroid Build Coastguard Worker args = "a, b" 6460*da0073e9SAndroid Build Coastguard Worker else: 6461*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Test doesn't support more than 2 arguments") 6462*da0073e9SAndroid Build Coastguard Worker if args_type is None: 6463*da0073e9SAndroid Build Coastguard Worker args_type = "(float, float)" if is_float else "(int, int)" 6464*da0073e9SAndroid Build Coastguard Worker funcs_str = funcs_template.format(func=func_name, args=args, args_type=args_type, ret_type=ret_type) 6465*da0073e9SAndroid Build Coastguard Worker scope = {} 6466*da0073e9SAndroid Build Coastguard Worker execWrapper(funcs_str, globals(), scope) 6467*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(funcs_str) 6468*da0073e9SAndroid Build Coastguard Worker f_script = cu.func 6469*da0073e9SAndroid Build Coastguard Worker f = scope['func'] 6470*da0073e9SAndroid Build Coastguard Worker 6471*da0073e9SAndroid Build 
Coastguard Worker if vals is None: 6472*da0073e9SAndroid Build Coastguard Worker vals = float_vals if is_float else int_vals 6473*da0073e9SAndroid Build Coastguard Worker vals = [(i, j) for i in vals for j in vals] 6474*da0073e9SAndroid Build Coastguard Worker 6475*da0073e9SAndroid Build Coastguard Worker for a, b in vals: 6476*da0073e9SAndroid Build Coastguard Worker res_python = None 6477*da0073e9SAndroid Build Coastguard Worker res_script = None 6478*da0073e9SAndroid Build Coastguard Worker try: 6479*da0073e9SAndroid Build Coastguard Worker res_python = f(a, b) 6480*da0073e9SAndroid Build Coastguard Worker except Exception as e: 6481*da0073e9SAndroid Build Coastguard Worker res_python = e 6482*da0073e9SAndroid Build Coastguard Worker try: 6483*da0073e9SAndroid Build Coastguard Worker res_script = f_script(a, b) 6484*da0073e9SAndroid Build Coastguard Worker except Exception as e: 6485*da0073e9SAndroid Build Coastguard Worker res_script = e 6486*da0073e9SAndroid Build Coastguard Worker if debug: 6487*da0073e9SAndroid Build Coastguard Worker print("in: ", a, b) 6488*da0073e9SAndroid Build Coastguard Worker print("out: ", res_python, res_script) 6489*da0073e9SAndroid Build Coastguard Worker # We can't use assertEqual because of a couple of differences: 6490*da0073e9SAndroid Build Coastguard Worker # 1. nan == nan should return true 6491*da0073e9SAndroid Build Coastguard Worker # 2. When python functions throw an exception, we usually want to silently ignore them. 
6492*da0073e9SAndroid Build Coastguard Worker # (ie: We want to return `nan` for math.sqrt(-5)) 6493*da0073e9SAndroid Build Coastguard Worker if res_python != res_script: 6494*da0073e9SAndroid Build Coastguard Worker if isinstance(res_python, Exception): 6495*da0073e9SAndroid Build Coastguard Worker continue 6496*da0073e9SAndroid Build Coastguard Worker 6497*da0073e9SAndroid Build Coastguard Worker if type(res_python) == type(res_script): 6498*da0073e9SAndroid Build Coastguard Worker if isinstance(res_python, tuple) and (math.isnan(res_python[0]) == math.isnan(res_script[0])): 6499*da0073e9SAndroid Build Coastguard Worker continue 6500*da0073e9SAndroid Build Coastguard Worker if isinstance(res_python, float) and math.isnan(res_python) and math.isnan(res_script): 6501*da0073e9SAndroid Build Coastguard Worker continue 6502*da0073e9SAndroid Build Coastguard Worker msg = (f"Failed on {func_name} with inputs {a} {b}. Python: {res_python}, Script: {res_script}") 6503*da0073e9SAndroid Build Coastguard Worker # math.pow() behavior has changed in 3.11, see https://docs.python.org/3/library/math.html#math.pow 6504*da0073e9SAndroid Build Coastguard Worker if sys.version_info >= (3, 11) and func_name == "pow" and a == 0.0 and b == -math.inf: 6505*da0073e9SAndroid Build Coastguard Worker self.assertTrue(res_python == math.inf and type(res_script) is RuntimeError) 6506*da0073e9SAndroid Build Coastguard Worker else: 6507*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res_python, res_script, msg=msg, atol=(1e-4) * max(abs(res_python), res_script), rtol=0) 6508*da0073e9SAndroid Build Coastguard Worker 6509*da0073e9SAndroid Build Coastguard Worker unary_float_ops = ["log", "log1p", "log10", "exp", "sqrt", "gamma", "lgamma", "erf", 6510*da0073e9SAndroid Build Coastguard Worker "erfc", "expm1", "fabs", "acos", "asin", "atan", "cos", "sin", "tan", 6511*da0073e9SAndroid Build Coastguard Worker "asinh", "atanh", "acosh", "sinh", "cosh", "tanh", "degrees", "radians"] 
6512*da0073e9SAndroid Build Coastguard Worker binary_float_ops = ["atan2", "fmod", "copysign"] 6513*da0073e9SAndroid Build Coastguard Worker for op in unary_float_ops: 6514*da0073e9SAndroid Build Coastguard Worker checkMathWrap(op, 1) 6515*da0073e9SAndroid Build Coastguard Worker for op in binary_float_ops: 6516*da0073e9SAndroid Build Coastguard Worker checkMathWrap(op, 2) 6517*da0073e9SAndroid Build Coastguard Worker 6518*da0073e9SAndroid Build Coastguard Worker checkMath("modf", 1, ret_type="Tuple[float, float]") 6519*da0073e9SAndroid Build Coastguard Worker checkMath("frexp", 1, ret_type="Tuple[float, int]") 6520*da0073e9SAndroid Build Coastguard Worker checkMath("isnan", 1, ret_type="bool") 6521*da0073e9SAndroid Build Coastguard Worker checkMath("isinf", 1, ret_type="bool") 6522*da0073e9SAndroid Build Coastguard Worker checkMath("ldexp", 2, is_float=False, ret_type="float", args_type="(float, int)", 6523*da0073e9SAndroid Build Coastguard Worker vals=[(i, j) for i in float_vals for j in range(-10, 10)]) 6524*da0073e9SAndroid Build Coastguard Worker checkMath("pow", 2, is_float=False, ret_type="float") 6525*da0073e9SAndroid Build Coastguard Worker checkMath("pow", 2, is_float=True, ret_type="float") 6526*da0073e9SAndroid Build Coastguard Worker checkMathWrap("floor", ret_type="int") 6527*da0073e9SAndroid Build Coastguard Worker checkMathWrap("ceil", ret_type="int") 6528*da0073e9SAndroid Build Coastguard Worker checkMathWrap("gcd", 2, is_float=False, ret_type="int") 6529*da0073e9SAndroid Build Coastguard Worker checkMath("isfinite", 1, ret_type="bool") 6530*da0073e9SAndroid Build Coastguard Worker checkMathWrap("remainder", 2) 6531*da0073e9SAndroid Build Coastguard Worker checkMathWrap("factorial", 1, is_float=False, ret_type="int", vals=[(i, 0) for i in range(-2, 10)]) 6532*da0073e9SAndroid Build Coastguard Worker 6533*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 6534*da0073e9SAndroid Build Coastguard Worker 
def test_if_nest_while(self): 6535*da0073e9SAndroid Build Coastguard Worker def func(a, b): 6536*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> int 6537*da0073e9SAndroid Build Coastguard Worker c = 0 6538*da0073e9SAndroid Build Coastguard Worker if a > b: 6539*da0073e9SAndroid Build Coastguard Worker while a > b: 6540*da0073e9SAndroid Build Coastguard Worker b = b + 1 6541*da0073e9SAndroid Build Coastguard Worker c = -b 6542*da0073e9SAndroid Build Coastguard Worker return c 6543*da0073e9SAndroid Build Coastguard Worker 6544*da0073e9SAndroid Build Coastguard Worker inputs = self._make_scalar_vars([4321, 1234], torch.int64) 6545*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, inputs) 6546*da0073e9SAndroid Build Coastguard Worker 6547*da0073e9SAndroid Build Coastguard Worker def test_script_optional_none(self): 6548*da0073e9SAndroid Build Coastguard Worker def none_stmt(x): 6549*da0073e9SAndroid Build Coastguard Worker output = None 6550*da0073e9SAndroid Build Coastguard Worker output = x 6551*da0073e9SAndroid Build Coastguard Worker return output 6552*da0073e9SAndroid Build Coastguard Worker 6553*da0073e9SAndroid Build Coastguard Worker def none_args(x): 6554*da0073e9SAndroid Build Coastguard Worker # type: (Optional[Tensor]) -> Optional[Tensor] 6555*da0073e9SAndroid Build Coastguard Worker return None 6556*da0073e9SAndroid Build Coastguard Worker 6557*da0073e9SAndroid Build Coastguard Worker self.checkScript(none_stmt, [torch.arange(0, 2)], optimize=True) 6558*da0073e9SAndroid Build Coastguard Worker self.checkScript(none_args, [None], optimize=True) 6559*da0073e9SAndroid Build Coastguard Worker 6560*da0073e9SAndroid Build Coastguard Worker # test undefined tensor None as default param 6561*da0073e9SAndroid Build Coastguard Worker def test_script_optional_tensor_none(x=None): 6562*da0073e9SAndroid Build Coastguard Worker # type: (Optional[Tensor]) -> Tensor 6563*da0073e9SAndroid Build Coastguard Worker res = torch.zeros(1, 
dtype=torch.int8) 6564*da0073e9SAndroid Build Coastguard Worker if x is None: 6565*da0073e9SAndroid Build Coastguard Worker res = res + 1 6566*da0073e9SAndroid Build Coastguard Worker else: 6567*da0073e9SAndroid Build Coastguard Worker res = x 6568*da0073e9SAndroid Build Coastguard Worker return res 6569*da0073e9SAndroid Build Coastguard Worker 6570*da0073e9SAndroid Build Coastguard Worker fn = test_script_optional_tensor_none 6571*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(fn) 6572*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(), scripted_fn()) 6573*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(torch.zeros(1)), scripted_fn(torch.zeros(1))) 6574*da0073e9SAndroid Build Coastguard Worker 6575*da0073e9SAndroid Build Coastguard Worker # test typical None as default param 6576*da0073e9SAndroid Build Coastguard Worker def test_script_optional_other_none(x=None): 6577*da0073e9SAndroid Build Coastguard Worker # type: (Optional[float]) -> float 6578*da0073e9SAndroid Build Coastguard Worker res = 2.0 6579*da0073e9SAndroid Build Coastguard Worker if x is None: 6580*da0073e9SAndroid Build Coastguard Worker res = res + 1.0 6581*da0073e9SAndroid Build Coastguard Worker else: 6582*da0073e9SAndroid Build Coastguard Worker res = x 6583*da0073e9SAndroid Build Coastguard Worker return res 6584*da0073e9SAndroid Build Coastguard Worker 6585*da0073e9SAndroid Build Coastguard Worker fn = test_script_optional_other_none 6586*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(fn) 6587*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(), scripted_fn()) 6588*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(1.0), scripted_fn(1.0)) 6589*da0073e9SAndroid Build Coastguard Worker 6590*da0073e9SAndroid Build Coastguard Worker def test_script_clamp_none(self): 6591*da0073e9SAndroid Build Coastguard Worker def test_script_clamp_max_none(x): 6592*da0073e9SAndroid Build Coastguard Worker return 
torch.clamp(x, min=2, max=None) 6593*da0073e9SAndroid Build Coastguard Worker 6594*da0073e9SAndroid Build Coastguard Worker def test_script_clamp_max(x): 6595*da0073e9SAndroid Build Coastguard Worker return torch.clamp(x, max=2) 6596*da0073e9SAndroid Build Coastguard Worker 6597*da0073e9SAndroid Build Coastguard Worker def test_script_clamp_min_none(x): 6598*da0073e9SAndroid Build Coastguard Worker return torch.clamp(x, min=None, max=2) 6599*da0073e9SAndroid Build Coastguard Worker 6600*da0073e9SAndroid Build Coastguard Worker def test_script_clamp_min(x): 6601*da0073e9SAndroid Build Coastguard Worker return torch.clamp(x, min=2) 6602*da0073e9SAndroid Build Coastguard Worker 6603*da0073e9SAndroid Build Coastguard Worker input = [torch.arange(0, 3)] 6604*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_script_clamp_max_none, input, optimize=True) 6605*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_script_clamp_max, input, optimize=True) 6606*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_script_clamp_min_none, input, optimize=True) 6607*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_script_clamp_min, input, optimize=True) 6608*da0073e9SAndroid Build Coastguard Worker 6609*da0073e9SAndroid Build Coastguard Worker def test_script_bool_constant(self): 6610*da0073e9SAndroid Build Coastguard Worker def test_script_bool_constant(): 6611*da0073e9SAndroid Build Coastguard Worker a = True 6612*da0073e9SAndroid Build Coastguard Worker return a 6613*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_script_bool_constant, []) 6614*da0073e9SAndroid Build Coastguard Worker 6615*da0073e9SAndroid Build Coastguard Worker def test_ternary(self): 6616*da0073e9SAndroid Build Coastguard Worker def func(a, b): 6617*da0073e9SAndroid Build Coastguard Worker c = 3 6618*da0073e9SAndroid Build Coastguard Worker c = a + b if bool(a > 3) else b 6619*da0073e9SAndroid Build Coastguard Worker return c 6620*da0073e9SAndroid 
Build Coastguard Worker 6621*da0073e9SAndroid Build Coastguard Worker inputs_true = self._make_scalar_vars([5, 2], torch.int64) 6622*da0073e9SAndroid Build Coastguard Worker inputs_false = self._make_scalar_vars([1, 0], torch.int64) 6623*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, inputs_true, optimize=True) 6624*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, inputs_false, optimize=True) 6625*da0073e9SAndroid Build Coastguard Worker 6626*da0073e9SAndroid Build Coastguard Worker def test_ternary_module_type_hint(self): 6627*da0073e9SAndroid Build Coastguard Worker class M1(torch.nn.Module): 6628*da0073e9SAndroid Build Coastguard Worker def forward(self) -> Any: 6629*da0073e9SAndroid Build Coastguard Worker return 'out' if self.training else {} 6630*da0073e9SAndroid Build Coastguard Worker 6631*da0073e9SAndroid Build Coastguard Worker class M2(torch.nn.Module): 6632*da0073e9SAndroid Build Coastguard Worker def forward(self) -> Any: 6633*da0073e9SAndroid Build Coastguard Worker out: Any = 'out' if self.training else {} 6634*da0073e9SAndroid Build Coastguard Worker return out 6635*da0073e9SAndroid Build Coastguard Worker 6636*da0073e9SAndroid Build Coastguard Worker class M3(torch.nn.Module): 6637*da0073e9SAndroid Build Coastguard Worker def forward(self) -> Optional[int]: 6638*da0073e9SAndroid Build Coastguard Worker return None if self.training else 1 6639*da0073e9SAndroid Build Coastguard Worker 6640*da0073e9SAndroid Build Coastguard Worker for module in [M1, M2, M3]: 6641*da0073e9SAndroid Build Coastguard Worker self.checkModule(module().train(), ()) 6642*da0073e9SAndroid Build Coastguard Worker self.checkModule(module().eval(), ()) 6643*da0073e9SAndroid Build Coastguard Worker 6644*da0073e9SAndroid Build Coastguard Worker def test_ternary_static_if(self): 6645*da0073e9SAndroid Build Coastguard Worker # Test for True branch when condition variable 6646*da0073e9SAndroid Build Coastguard Worker # is annotated as Final 
6647*da0073e9SAndroid Build Coastguard Worker class M1(torch.nn.Module): 6648*da0073e9SAndroid Build Coastguard Worker flag: torch.jit.Final[bool] 6649*da0073e9SAndroid Build Coastguard Worker 6650*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 6651*da0073e9SAndroid Build Coastguard Worker super().__init__() 6652*da0073e9SAndroid Build Coastguard Worker self.flag = True 6653*da0073e9SAndroid Build Coastguard Worker 6654*da0073e9SAndroid Build Coastguard Worker def forward(self) -> torch.Tensor: 6655*da0073e9SAndroid Build Coastguard Worker return torch.ones(3) if self.flag else {} 6656*da0073e9SAndroid Build Coastguard Worker 6657*da0073e9SAndroid Build Coastguard Worker # Test for True branch when condition variable 6658*da0073e9SAndroid Build Coastguard Worker # is annotated as Final 6659*da0073e9SAndroid Build Coastguard Worker class M2(torch.nn.Module): 6660*da0073e9SAndroid Build Coastguard Worker flag: torch.jit.Final[bool] 6661*da0073e9SAndroid Build Coastguard Worker 6662*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 6663*da0073e9SAndroid Build Coastguard Worker super().__init__() 6664*da0073e9SAndroid Build Coastguard Worker self.flag = False 6665*da0073e9SAndroid Build Coastguard Worker 6666*da0073e9SAndroid Build Coastguard Worker def forward(self) -> torch.Tensor: 6667*da0073e9SAndroid Build Coastguard Worker return {} if self.flag else torch.ones(3) 6668*da0073e9SAndroid Build Coastguard Worker 6669*da0073e9SAndroid Build Coastguard Worker model1 = M1() 6670*da0073e9SAndroid Build Coastguard Worker model2 = M2() 6671*da0073e9SAndroid Build Coastguard Worker script_model_1 = torch.jit.script(model1) 6672*da0073e9SAndroid Build Coastguard Worker script_model_2 = torch.jit.script(model2) 6673*da0073e9SAndroid Build Coastguard Worker self.assertEqual(model1.forward(), script_model_1.forward()) 6674*da0073e9SAndroid Build Coastguard Worker self.assertEqual(model2.forward(), script_model_2.forward()) 
6675*da0073e9SAndroid Build Coastguard Worker 6676*da0073e9SAndroid Build Coastguard Worker def test_ternary_right_associative(self): 6677*da0073e9SAndroid Build Coastguard Worker def plus_123(x: int): 6678*da0073e9SAndroid Build Coastguard Worker return x + 1 if x == 1 else x + 2 if x == 2 else x + 3 6679*da0073e9SAndroid Build Coastguard Worker self.checkScript(plus_123, (1,)) 6680*da0073e9SAndroid Build Coastguard Worker self.checkScript(plus_123, (2,)) 6681*da0073e9SAndroid Build Coastguard Worker self.checkScript(plus_123, (3,)) 6682*da0073e9SAndroid Build Coastguard Worker 6683*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 6684*da0073e9SAndroid Build Coastguard Worker def test_print(self): 6685*da0073e9SAndroid Build Coastguard Worker def func(x, y): 6686*da0073e9SAndroid Build Coastguard Worker q = (x + y).sigmoid() 6687*da0073e9SAndroid Build Coastguard Worker print(q, 1, 2, [1, 2], [1.0, 2.0]) 6688*da0073e9SAndroid Build Coastguard Worker w = -q 6689*da0073e9SAndroid Build Coastguard Worker return w * w 6690*da0073e9SAndroid Build Coastguard Worker 6691*da0073e9SAndroid Build Coastguard Worker x = torch.arange(4., requires_grad=True) 6692*da0073e9SAndroid Build Coastguard Worker y = torch.arange(0., 8, 2, requires_grad=True) 6693*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, [x, y], optimize=True, capture_output=True) 6694*da0073e9SAndroid Build Coastguard Worker 6695*da0073e9SAndroid Build Coastguard Worker def test_format(self): 6696*da0073e9SAndroid Build Coastguard Worker def func(x): 6697*da0073e9SAndroid Build Coastguard Worker print("{}, I'm a {}".format("Hello", "test")) 6698*da0073e9SAndroid Build Coastguard Worker print("format blank".format()) 6699*da0073e9SAndroid Build Coastguard Worker print("stuff before {}".format("hi")) 6700*da0073e9SAndroid Build Coastguard Worker print("{} stuff after".format("hi")) 6701*da0073e9SAndroid Build Coastguard Worker return x + 1 
6702*da0073e9SAndroid Build Coastguard Worker 6703*da0073e9SAndroid Build Coastguard Worker x = torch.arange(4., requires_grad=True) 6704*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, [x], optimize=True, capture_output=True) 6705*da0073e9SAndroid Build Coastguard Worker 6706*da0073e9SAndroid Build Coastguard Worker def test_logical_short_circuit(self): 6707*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6708*da0073e9SAndroid Build Coastguard Worker def testNoThrows(t): 6709*da0073e9SAndroid Build Coastguard Worker c1 = 1 6710*da0073e9SAndroid Build Coastguard Worker if (False and bool(t[1])) or (True or bool(t[1])): 6711*da0073e9SAndroid Build Coastguard Worker c1 = 0 6712*da0073e9SAndroid Build Coastguard Worker return c1 6713*da0073e9SAndroid Build Coastguard Worker 6714*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::If").run(testNoThrows.graph) 6715*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, testNoThrows(torch.randn(0))) 6716*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, testNoThrows(torch.randn([2, 3]))) 6717*da0073e9SAndroid Build Coastguard Worker 6718*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6719*da0073e9SAndroid Build Coastguard Worker def throwsOr(t): 6720*da0073e9SAndroid Build Coastguard Worker c0 = False or bool(t[1]) 6721*da0073e9SAndroid Build Coastguard Worker print(c0) 6722*da0073e9SAndroid Build Coastguard Worker 6723*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6724*da0073e9SAndroid Build Coastguard Worker def throwsAnd(t): 6725*da0073e9SAndroid Build Coastguard Worker c0 = True and bool(t[1]) 6726*da0073e9SAndroid Build Coastguard Worker print(c0) 6727*da0073e9SAndroid Build Coastguard Worker 6728*da0073e9SAndroid Build Coastguard Worker t = torch.randn(0) 6729*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"): 6730*da0073e9SAndroid Build Coastguard Worker 
throwsOr(t) 6731*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"): 6732*da0073e9SAndroid Build Coastguard Worker throwsAnd(t) 6733*da0073e9SAndroid Build Coastguard Worker 6734*da0073e9SAndroid Build Coastguard Worker def test_type_cast(self): 6735*da0073e9SAndroid Build Coastguard Worker template = dedent(''' 6736*da0073e9SAndroid Build Coastguard Worker def func(v): 6737*da0073e9SAndroid Build Coastguard Worker # type: ({from_type}) -> {to_type} 6738*da0073e9SAndroid Build Coastguard Worker return {to_type}(v) 6739*da0073e9SAndroid Build Coastguard Worker ''') 6740*da0073e9SAndroid Build Coastguard Worker 6741*da0073e9SAndroid Build Coastguard Worker def check_cast(from_type, to_type, value, raises=False): 6742*da0073e9SAndroid Build Coastguard Worker code = template.format(from_type=from_type, to_type=to_type) 6743*da0073e9SAndroid Build Coastguard Worker self.checkScript(code, (value,)) 6744*da0073e9SAndroid Build Coastguard Worker 6745*da0073e9SAndroid Build Coastguard Worker check_cast('int', 'float', 1) 6746*da0073e9SAndroid Build Coastguard Worker check_cast('int', 'bool', 1) 6747*da0073e9SAndroid Build Coastguard Worker check_cast('int', 'bool', 0) 6748*da0073e9SAndroid Build Coastguard Worker 6749*da0073e9SAndroid Build Coastguard Worker check_cast('float', 'int', 1.) 6750*da0073e9SAndroid Build Coastguard Worker check_cast('float', 'bool', 1.) 6751*da0073e9SAndroid Build Coastguard Worker check_cast('float', 'bool', 0.) 
6752*da0073e9SAndroid Build Coastguard Worker 6753*da0073e9SAndroid Build Coastguard Worker check_cast('bool', 'int', True) 6754*da0073e9SAndroid Build Coastguard Worker check_cast('bool', 'float', True) 6755*da0073e9SAndroid Build Coastguard Worker 6756*da0073e9SAndroid Build Coastguard Worker def test_multiple_assignment(self): 6757*da0073e9SAndroid Build Coastguard Worker def outer_func(x): 6758*da0073e9SAndroid Build Coastguard Worker return x * 2, x + 2 6759*da0073e9SAndroid Build Coastguard Worker 6760*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6761*da0073e9SAndroid Build Coastguard Worker def func(x): 6762*da0073e9SAndroid Build Coastguard Worker y, z = outer_func(x) 6763*da0073e9SAndroid Build Coastguard Worker return y + z 6764*da0073e9SAndroid Build Coastguard Worker 6765*da0073e9SAndroid Build Coastguard Worker x = torch.arange(4) 6766*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(x), x * 2 + x + 2) 6767*da0073e9SAndroid Build Coastguard Worker 6768*da0073e9SAndroid Build Coastguard Worker def test_literals(self): 6769*da0073e9SAndroid Build Coastguard Worker def func(a): 6770*da0073e9SAndroid Build Coastguard Worker return a.view(size=[1, 2, 3]) 6771*da0073e9SAndroid Build Coastguard Worker 6772*da0073e9SAndroid Build Coastguard Worker a = torch.randn(6) 6773*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, [a], optimize=True) 6774*da0073e9SAndroid Build Coastguard Worker 6775*da0073e9SAndroid Build Coastguard Worker def test_return(self): 6776*da0073e9SAndroid Build Coastguard Worker def no_return(a): 6777*da0073e9SAndroid Build Coastguard Worker a + 1 6778*da0073e9SAndroid Build Coastguard Worker 6779*da0073e9SAndroid Build Coastguard Worker def void_return(a): 6780*da0073e9SAndroid Build Coastguard Worker return 6781*da0073e9SAndroid Build Coastguard Worker 6782*da0073e9SAndroid Build Coastguard Worker def one_return(a): 6783*da0073e9SAndroid Build Coastguard Worker return a + 1. 
6784*da0073e9SAndroid Build Coastguard Worker 6785*da0073e9SAndroid Build Coastguard Worker def multiple_returns(a): 6786*da0073e9SAndroid Build Coastguard Worker return a * 1., a * 2., a * 3. 6787*da0073e9SAndroid Build Coastguard Worker 6788*da0073e9SAndroid Build Coastguard Worker a = torch.randn(1, dtype=torch.float) 6789*da0073e9SAndroid Build Coastguard Worker self.checkScript(no_return, [a], optimize=True) 6790*da0073e9SAndroid Build Coastguard Worker self.checkScript(void_return, [a], optimize=True) 6791*da0073e9SAndroid Build Coastguard Worker self.checkScript(one_return, [a], optimize=True) 6792*da0073e9SAndroid Build Coastguard Worker self.checkScript(multiple_returns, [a], optimize=True) 6793*da0073e9SAndroid Build Coastguard Worker 6794*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "does not return along all paths"): 6795*da0073e9SAndroid Build Coastguard Worker torch.jit.CompilationUnit(''' 6796*da0073e9SAndroid Build Coastguard Worker def no_return_bad_annotation(a): 6797*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 6798*da0073e9SAndroid Build Coastguard Worker a + 1 6799*da0073e9SAndroid Build Coastguard Worker ''') 6800*da0073e9SAndroid Build Coastguard Worker 6801*da0073e9SAndroid Build Coastguard Worker def test_error(self): 6802*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6803*da0073e9SAndroid Build Coastguard Worker def foo(a): 6804*da0073e9SAndroid Build Coastguard Worker return a.t() 6805*da0073e9SAndroid Build Coastguard Worker s = Variable(torch.rand(5, 5, 5)) 6806*da0073e9SAndroid Build Coastguard Worker # XXX: this should stay quiet in stay propagation and only fail in the interpreter 6807*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"): 6808*da0073e9SAndroid Build Coastguard Worker foo(s) 6809*da0073e9SAndroid Build Coastguard Worker 6810*da0073e9SAndroid Build Coastguard Worker 
@torch.jit.script 6811*da0073e9SAndroid Build Coastguard Worker def bar(c, b): 6812*da0073e9SAndroid Build Coastguard Worker return c + b 6813*da0073e9SAndroid Build Coastguard Worker 6814*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "failed in the TorchScript interpreter"): 6815*da0073e9SAndroid Build Coastguard Worker bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True)) 6816*da0073e9SAndroid Build Coastguard Worker 6817*da0073e9SAndroid Build Coastguard Worker def test_error_stacktrace(self): 6818*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6819*da0073e9SAndroid Build Coastguard Worker def baz(c, b): 6820*da0073e9SAndroid Build Coastguard Worker return c + b 6821*da0073e9SAndroid Build Coastguard Worker 6822*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6823*da0073e9SAndroid Build Coastguard Worker def foo(c, b): 6824*da0073e9SAndroid Build Coastguard Worker return baz(c, b) 6825*da0073e9SAndroid Build Coastguard Worker 6826*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6827*da0073e9SAndroid Build Coastguard Worker def bar(c, b): 6828*da0073e9SAndroid Build Coastguard Worker return foo(c, b) 6829*da0073e9SAndroid Build Coastguard Worker 6830*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError) as cm: 6831*da0073e9SAndroid Build Coastguard Worker bar(torch.rand(10), torch.rand(9)) 6832*da0073e9SAndroid Build Coastguard Worker FileCheck().check("The following operation failed in the TorchScript interpreter") \ 6833*da0073e9SAndroid Build Coastguard Worker .check("Traceback") \ 6834*da0073e9SAndroid Build Coastguard Worker .check("in foo").check("in baz").run(str(cm.exception)) 6835*da0073e9SAndroid Build Coastguard Worker 6836*da0073e9SAndroid Build Coastguard Worker def test_error_stacktrace_interface(self): 6837*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6838*da0073e9SAndroid Build Coastguard Worker def 
baz(c, b): 6839*da0073e9SAndroid Build Coastguard Worker return c + b 6840*da0073e9SAndroid Build Coastguard Worker 6841*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6842*da0073e9SAndroid Build Coastguard Worker def foo(c, b): 6843*da0073e9SAndroid Build Coastguard Worker return baz(c, b) 6844*da0073e9SAndroid Build Coastguard Worker 6845*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6846*da0073e9SAndroid Build Coastguard Worker def bar(c, b): 6847*da0073e9SAndroid Build Coastguard Worker return foo(c, b) 6848*da0073e9SAndroid Build Coastguard Worker 6849*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6850*da0073e9SAndroid Build Coastguard Worker class Bar: 6851*da0073e9SAndroid Build Coastguard Worker def one(self, x, y): 6852*da0073e9SAndroid Build Coastguard Worker return bar(x, y) 6853*da0073e9SAndroid Build Coastguard Worker 6854*da0073e9SAndroid Build Coastguard Worker @torch.jit.interface 6855*da0073e9SAndroid Build Coastguard Worker class IFace: 6856*da0073e9SAndroid Build Coastguard Worker def one(self, x, y): 6857*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> Tensor 6858*da0073e9SAndroid Build Coastguard Worker pass 6859*da0073e9SAndroid Build Coastguard Worker 6860*da0073e9SAndroid Build Coastguard Worker make_global(IFace) 6861*da0073e9SAndroid Build Coastguard Worker 6862*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6863*da0073e9SAndroid Build Coastguard Worker def as_interface(x): 6864*da0073e9SAndroid Build Coastguard Worker # type: (IFace) -> IFace 6865*da0073e9SAndroid Build Coastguard Worker return x 6866*da0073e9SAndroid Build Coastguard Worker 6867*da0073e9SAndroid Build Coastguard Worker f = as_interface(Bar()) 6868*da0073e9SAndroid Build Coastguard Worker 6869*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError) as cm: 6870*da0073e9SAndroid Build Coastguard Worker x = f.one(torch.rand(10), torch.rand(9)) 6871*da0073e9SAndroid Build Coastguard 
Worker bar(torch.rand(10), torch.rand(9)) 6872*da0073e9SAndroid Build Coastguard Worker FileCheck().check("The following operation failed in the TorchScript interpreter") \ 6873*da0073e9SAndroid Build Coastguard Worker .check("Traceback") \ 6874*da0073e9SAndroid Build Coastguard Worker .check("in foo").check("in baz").run(str(cm.exception)) 6875*da0073e9SAndroid Build Coastguard Worker 6876*da0073e9SAndroid Build Coastguard Worker def test_operator_precedence(self): 6877*da0073e9SAndroid Build Coastguard Worker def double(x): 6878*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 6879*da0073e9SAndroid Build Coastguard Worker return 2 * x 6880*da0073e9SAndroid Build Coastguard Worker 6881*da0073e9SAndroid Build Coastguard Worker def complicated_arithmetic_operation(): 6882*da0073e9SAndroid Build Coastguard Worker # TODO we need to test exponent operator '**' and bitwise not 6883*da0073e9SAndroid Build Coastguard Worker # operator '~' once they are properly supported. 6884*da0073e9SAndroid Build Coastguard Worker list = [0, 1, 2, 3] 6885*da0073e9SAndroid Build Coastguard Worker result = list[1:3][0] + double(4) + (-3 + 8) * 6 // 2 % 4 << 2 + 1 >> 1 | 23 & 16 + 3 ^ 4 6886*da0073e9SAndroid Build Coastguard Worker return result 6887*da0073e9SAndroid Build Coastguard Worker 6888*da0073e9SAndroid Build Coastguard Worker self.checkScript(complicated_arithmetic_operation, ()) 6889*da0073e9SAndroid Build Coastguard Worker 6890*da0073e9SAndroid Build Coastguard Worker def test_in_operator_with_two_strings(self): 6891*da0073e9SAndroid Build Coastguard Worker def fn() -> bool: 6892*da0073e9SAndroid Build Coastguard Worker return "a" in "abcd" 6893*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 6894*da0073e9SAndroid Build Coastguard Worker 6895*da0073e9SAndroid Build Coastguard Worker def test_bitwise_ops(self): 6896*da0073e9SAndroid Build Coastguard Worker 6897*da0073e9SAndroid Build Coastguard Worker def int_test(): 6898*da0073e9SAndroid Build 
Coastguard Worker return 2 & 3, 2 ^ 3, 2 | 3, 2 << 3, 2 >> 3 6899*da0073e9SAndroid Build Coastguard Worker 6900*da0073e9SAndroid Build Coastguard Worker self.checkScript(int_test, ()) 6901*da0073e9SAndroid Build Coastguard Worker 6902*da0073e9SAndroid Build Coastguard Worker def bool_test(x, y): 6903*da0073e9SAndroid Build Coastguard Worker # type: (bool, bool) -> Tuple[bool, bool, bool] 6904*da0073e9SAndroid Build Coastguard Worker return x & y, x ^ y, x | y 6905*da0073e9SAndroid Build Coastguard Worker 6906*da0073e9SAndroid Build Coastguard Worker self.checkScript(bool_test, (True, False)) 6907*da0073e9SAndroid Build Coastguard Worker self.checkScript(bool_test, (True, True)) 6908*da0073e9SAndroid Build Coastguard Worker 6909*da0073e9SAndroid Build Coastguard Worker def tensor_test(x, y): 6910*da0073e9SAndroid Build Coastguard Worker return x & y, x ^ y, x | y 6911*da0073e9SAndroid Build Coastguard Worker 6912*da0073e9SAndroid Build Coastguard Worker def tensor_with_int_test(x, y): 6913*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, int) -> Tuple[Tensor, Tensor] 6914*da0073e9SAndroid Build Coastguard Worker return x << y, x >> y 6915*da0073e9SAndroid Build Coastguard Worker 6916*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(2) 6917*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(3) 6918*da0073e9SAndroid Build Coastguard Worker 6919*da0073e9SAndroid Build Coastguard Worker self.checkScript(tensor_test, (x, y)) 6920*da0073e9SAndroid Build Coastguard Worker self.checkScript(tensor_with_int_test, (x, 2)) 6921*da0073e9SAndroid Build Coastguard Worker 6922*da0073e9SAndroid Build Coastguard Worker def not_test(x): 6923*da0073e9SAndroid Build Coastguard Worker return ~x 6924*da0073e9SAndroid Build Coastguard Worker 6925*da0073e9SAndroid Build Coastguard Worker self.checkScript(not_test, (torch.tensor([2, 4]), )) 6926*da0073e9SAndroid Build Coastguard Worker 6927*da0073e9SAndroid Build Coastguard Worker def test_all(self): 
6928*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6929*da0073e9SAndroid Build Coastguard Worker def test_all_tensor(x): 6930*da0073e9SAndroid Build Coastguard Worker return all(x) 6931*da0073e9SAndroid Build Coastguard Worker self.assertFalse(test_all_tensor(torch.tensor([1, 0, 3], dtype=torch.uint8))) 6932*da0073e9SAndroid Build Coastguard Worker self.assertTrue(test_all_tensor(torch.tensor([3.14, 3, 99], dtype=torch.uint8))) 6933*da0073e9SAndroid Build Coastguard Worker self.assertTrue(test_all_tensor(torch.tensor([True, True], dtype=torch.uint8))) 6934*da0073e9SAndroid Build Coastguard Worker self.assertFalse(test_all_tensor(torch.tensor([True, False], dtype=torch.uint8))) 6935*da0073e9SAndroid Build Coastguard Worker 6936*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6937*da0073e9SAndroid Build Coastguard Worker def test_all_bool_list(x): 6938*da0073e9SAndroid Build Coastguard Worker # type: (List[bool]) -> bool 6939*da0073e9SAndroid Build Coastguard Worker return all(x) 6940*da0073e9SAndroid Build Coastguard Worker self.assertTrue(test_all_bool_list([True, True])) 6941*da0073e9SAndroid Build Coastguard Worker self.assertTrue(test_all_bool_list([True, 1])) 6942*da0073e9SAndroid Build Coastguard Worker self.assertFalse(test_all_bool_list([True, False])) 6943*da0073e9SAndroid Build Coastguard Worker self.assertFalse(test_all_bool_list([True, 0])) 6944*da0073e9SAndroid Build Coastguard Worker self.assertFalse(test_all_bool_list([False, 0])) 6945*da0073e9SAndroid Build Coastguard Worker self.assertTrue(test_all_bool_list([])) 6946*da0073e9SAndroid Build Coastguard Worker 6947*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6948*da0073e9SAndroid Build Coastguard Worker def test_all_int_list(x): 6949*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> bool 6950*da0073e9SAndroid Build Coastguard Worker return all(x) 6951*da0073e9SAndroid Build Coastguard Worker self.assertTrue(test_all_int_list([3, 6])) 
6952*da0073e9SAndroid Build Coastguard Worker self.assertFalse(test_all_int_list([2, 0])) 6953*da0073e9SAndroid Build Coastguard Worker 6954*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 6955*da0073e9SAndroid Build Coastguard Worker def test_all_float_list(x): 6956*da0073e9SAndroid Build Coastguard Worker # type: (List[float]) -> bool 6957*da0073e9SAndroid Build Coastguard Worker return all(x) 6958*da0073e9SAndroid Build Coastguard Worker self.assertTrue(test_all_float_list([3.14, 8.1])) 6959*da0073e9SAndroid Build Coastguard Worker self.assertFalse(test_all_float_list([3.14, 0, 8.9])) 6960*da0073e9SAndroid Build Coastguard Worker 6961*da0073e9SAndroid Build Coastguard Worker 6962*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a TorchDynamo suitable test") 6963*da0073e9SAndroid Build Coastguard Worker def test_number_math(self): 6964*da0073e9SAndroid Build Coastguard Worker ops_template = dedent(''' 6965*da0073e9SAndroid Build Coastguard Worker def func(): 6966*da0073e9SAndroid Build Coastguard Worker return {scalar1} {op} {scalar2} 6967*da0073e9SAndroid Build Coastguard Worker ''') 6968*da0073e9SAndroid Build Coastguard Worker ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '//'] 6969*da0073e9SAndroid Build Coastguard Worker funcs_template = dedent(''' 6970*da0073e9SAndroid Build Coastguard Worker def func(): 6971*da0073e9SAndroid Build Coastguard Worker return {func}({scalar1}, {scalar2}) 6972*da0073e9SAndroid Build Coastguard Worker ''') 6973*da0073e9SAndroid Build Coastguard Worker funcs = ['min', 'max'] 6974*da0073e9SAndroid Build Coastguard Worker scalars = ['7', '2', '3', '-3', '3.14', '0.125', '-0.5', '2.0', '-2.0'] 6975*da0073e9SAndroid Build Coastguard Worker scalar_pairs = [(scalar1, scalar2) for scalar1 in scalars for scalar2 in scalars] 6976*da0073e9SAndroid Build Coastguard Worker 6977*da0073e9SAndroid Build Coastguard Worker def run_test(code): 6978*da0073e9SAndroid Build Coastguard Worker scope = {} 
6979*da0073e9SAndroid Build Coastguard Worker execWrapper(code, globals(), scope) 6980*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 6981*da0073e9SAndroid Build Coastguard Worker 6982*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cu.func(), scope['func']()) 6983*da0073e9SAndroid Build Coastguard Worker 6984*da0073e9SAndroid Build Coastguard Worker for scalar1, scalar2 in scalar_pairs: 6985*da0073e9SAndroid Build Coastguard Worker for op in ops: 6986*da0073e9SAndroid Build Coastguard Worker code = ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2) 6987*da0073e9SAndroid Build Coastguard Worker run_test(code) 6988*da0073e9SAndroid Build Coastguard Worker for func in funcs: 6989*da0073e9SAndroid Build Coastguard Worker code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2) 6990*da0073e9SAndroid Build Coastguard Worker run_test(code) 6991*da0073e9SAndroid Build Coastguard Worker 6992*da0073e9SAndroid Build Coastguard Worker # test Scalar overloads 6993*da0073e9SAndroid Build Coastguard Worker for scalar1, scalar2 in scalar_pairs: 6994*da0073e9SAndroid Build Coastguard Worker item1 = 'torch.tensor(' + scalar1 + ').item()' 6995*da0073e9SAndroid Build Coastguard Worker item2 = 'torch.tensor(' + scalar2 + ').item()' 6996*da0073e9SAndroid Build Coastguard Worker for op in ops: 6997*da0073e9SAndroid Build Coastguard Worker code = ops_template.format(op=op, scalar1=item1, scalar2=scalar2) 6998*da0073e9SAndroid Build Coastguard Worker run_test(code) 6999*da0073e9SAndroid Build Coastguard Worker code = ops_template.format(op=op, scalar1=scalar1, scalar2=item2) 7000*da0073e9SAndroid Build Coastguard Worker run_test(code) 7001*da0073e9SAndroid Build Coastguard Worker code = ops_template.format(op=op, scalar1=item1, scalar2=item2) 7002*da0073e9SAndroid Build Coastguard Worker run_test(code) 7003*da0073e9SAndroid Build Coastguard Worker for func in funcs: 7004*da0073e9SAndroid Build Coastguard Worker code = 
funcs_template.format(func=func, scalar1=item1, scalar2=scalar2) 7005*da0073e9SAndroid Build Coastguard Worker run_test(code) 7006*da0073e9SAndroid Build Coastguard Worker code = funcs_template.format(func=func, scalar1=scalar1, scalar2=item2) 7007*da0073e9SAndroid Build Coastguard Worker run_test(code) 7008*da0073e9SAndroid Build Coastguard Worker code = funcs_template.format(func=func, scalar1=item1, scalar2=item2) 7009*da0073e9SAndroid Build Coastguard Worker run_test(code) 7010*da0073e9SAndroid Build Coastguard Worker 7011*da0073e9SAndroid Build Coastguard Worker def test_number_abs(self): 7012*da0073e9SAndroid Build Coastguard Worker def func1(x): 7013*da0073e9SAndroid Build Coastguard Worker # type: (float) -> float 7014*da0073e9SAndroid Build Coastguard Worker return abs(x) 7015*da0073e9SAndroid Build Coastguard Worker 7016*da0073e9SAndroid Build Coastguard Worker def func2(x): 7017*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 7018*da0073e9SAndroid Build Coastguard Worker return abs(x) 7019*da0073e9SAndroid Build Coastguard Worker 7020*da0073e9SAndroid Build Coastguard Worker def func3(x): 7021*da0073e9SAndroid Build Coastguard Worker return abs(x) 7022*da0073e9SAndroid Build Coastguard Worker 7023*da0073e9SAndroid Build Coastguard Worker self.checkScript(func1, (-3.14,)) 7024*da0073e9SAndroid Build Coastguard Worker self.checkScript(func1, (3.14,)) 7025*da0073e9SAndroid Build Coastguard Worker self.checkScript(func2, (-10,)) 7026*da0073e9SAndroid Build Coastguard Worker self.checkScript(func2, (10,)) 7027*da0073e9SAndroid Build Coastguard Worker self.checkScript(func3, (torch.tensor([-5, -10, -20]),)) 7028*da0073e9SAndroid Build Coastguard Worker self.checkScript(func3, (torch.tensor([5, 10, 20]),)) 7029*da0073e9SAndroid Build Coastguard Worker self.checkScript(func3, (torch.tensor([-5, 10, -20]),)) 7030*da0073e9SAndroid Build Coastguard Worker 7031*da0073e9SAndroid Build Coastguard Worker def test_number_div(self): 7032*da0073e9SAndroid 
Build Coastguard Worker self.assertEqual(div_int_future(), torch.jit.script(div_int_future)()) 7033*da0073e9SAndroid Build Coastguard Worker self.checkScript(div_float_future, ()) 7034*da0073e9SAndroid Build Coastguard Worker 7035*da0073e9SAndroid Build Coastguard Worker self.checkScript(div_int_nofuture, ()) 7036*da0073e9SAndroid Build Coastguard Worker self.checkScript(div_float_nofuture, ()) 7037*da0073e9SAndroid Build Coastguard Worker 7038*da0073e9SAndroid Build Coastguard Worker # Testing bitwise shorthand aug assignment 7039*da0073e9SAndroid Build Coastguard Worker def test_bool_augassign_bitwise_or(self): 7040*da0073e9SAndroid Build Coastguard Worker def func(a: bool, b: bool) -> bool: 7041*da0073e9SAndroid Build Coastguard Worker a |= b 7042*da0073e9SAndroid Build Coastguard Worker return a 7043*da0073e9SAndroid Build Coastguard Worker 7044*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (True, False), optimize=True) 7045*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (True, True), optimize=True) 7046*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (False, False), optimize=True) 7047*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (False, True), optimize=True) 7048*da0073e9SAndroid Build Coastguard Worker 7049*da0073e9SAndroid Build Coastguard Worker def test_bool_augassign_bitwise_and(self): 7050*da0073e9SAndroid Build Coastguard Worker def func(a: bool, b: bool) -> bool: 7051*da0073e9SAndroid Build Coastguard Worker a &= b 7052*da0073e9SAndroid Build Coastguard Worker return a 7053*da0073e9SAndroid Build Coastguard Worker 7054*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (True, False), optimize=True) 7055*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (True, True), optimize=True) 7056*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (False, False), optimize=True) 7057*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (False, 
True), optimize=True) 7058*da0073e9SAndroid Build Coastguard Worker 7059*da0073e9SAndroid Build Coastguard Worker def test_bool_augassign_bitwise_xor(self): 7060*da0073e9SAndroid Build Coastguard Worker def func(a: bool, b: bool) -> bool: 7061*da0073e9SAndroid Build Coastguard Worker a ^= b 7062*da0073e9SAndroid Build Coastguard Worker return a 7063*da0073e9SAndroid Build Coastguard Worker 7064*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (True, False), optimize=True) 7065*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (True, True), optimize=True) 7066*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (False, False), optimize=True) 7067*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (False, True), optimize=True) 7068*da0073e9SAndroid Build Coastguard Worker 7069*da0073e9SAndroid Build Coastguard Worker def test_number_augassign_bitwise_lshift(self): 7070*da0073e9SAndroid Build Coastguard Worker def func() -> int: 7071*da0073e9SAndroid Build Coastguard Worker z = 8 7072*da0073e9SAndroid Build Coastguard Worker z <<= 2 7073*da0073e9SAndroid Build Coastguard Worker return z 7074*da0073e9SAndroid Build Coastguard Worker 7075*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (), optimize=True) 7076*da0073e9SAndroid Build Coastguard Worker 7077*da0073e9SAndroid Build Coastguard Worker def test_number_augassign_bitwise_rshift(self): 7078*da0073e9SAndroid Build Coastguard Worker def func() -> int: 7079*da0073e9SAndroid Build Coastguard Worker z = 8 7080*da0073e9SAndroid Build Coastguard Worker z >>= 2 7081*da0073e9SAndroid Build Coastguard Worker return z 7082*da0073e9SAndroid Build Coastguard Worker 7083*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (), optimize=True) 7084*da0073e9SAndroid Build Coastguard Worker 7085*da0073e9SAndroid Build Coastguard Worker def test_number_augassign_bitwise_pow(self): 7086*da0073e9SAndroid Build Coastguard Worker def func() -> float: 
7087*da0073e9SAndroid Build Coastguard Worker z = 8 7088*da0073e9SAndroid Build Coastguard Worker z **= 2 7089*da0073e9SAndroid Build Coastguard Worker return z 7090*da0073e9SAndroid Build Coastguard Worker 7091*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (), optimize=True) 7092*da0073e9SAndroid Build Coastguard Worker 7093*da0073e9SAndroid Build Coastguard Worker def test_number_augassign(self): 7094*da0073e9SAndroid Build Coastguard Worker def func(): 7095*da0073e9SAndroid Build Coastguard Worker z = 1 7096*da0073e9SAndroid Build Coastguard Worker z += 2 7097*da0073e9SAndroid Build Coastguard Worker return z 7098*da0073e9SAndroid Build Coastguard Worker 7099*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (), optimize=True) 7100*da0073e9SAndroid Build Coastguard Worker 7101*da0073e9SAndroid Build Coastguard Worker def test_nested_select_assign(self): 7102*da0073e9SAndroid Build Coastguard Worker class SubSubModule(torch.nn.Module): 7103*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 7104*da0073e9SAndroid Build Coastguard Worker super().__init__() 7105*da0073e9SAndroid Build Coastguard Worker self.abc = 11 7106*da0073e9SAndroid Build Coastguard Worker 7107*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 7108*da0073e9SAndroid Build Coastguard Worker return self.abc 7109*da0073e9SAndroid Build Coastguard Worker 7110*da0073e9SAndroid Build Coastguard Worker class SubModule(torch.nn.Module): 7111*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 7112*da0073e9SAndroid Build Coastguard Worker super().__init__() 7113*da0073e9SAndroid Build Coastguard Worker self.a = 11 7114*da0073e9SAndroid Build Coastguard Worker self.nested = SubSubModule() 7115*da0073e9SAndroid Build Coastguard Worker 7116*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 7117*da0073e9SAndroid Build Coastguard Worker return self.a 7118*da0073e9SAndroid Build Coastguard Worker 7119*da0073e9SAndroid 
Build Coastguard Worker class TestModule(torch.nn.Module): 7120*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 7121*da0073e9SAndroid Build Coastguard Worker super().__init__() 7122*da0073e9SAndroid Build Coastguard Worker self.sub = SubModule() 7123*da0073e9SAndroid Build Coastguard Worker self.hi = 1 7124*da0073e9SAndroid Build Coastguard Worker 7125*da0073e9SAndroid Build Coastguard Worker def forward(self): 7126*da0073e9SAndroid Build Coastguard Worker self.hi = 5 7127*da0073e9SAndroid Build Coastguard Worker self.sub.a = 1 7128*da0073e9SAndroid Build Coastguard Worker self.sub.nested.abc = 5 7129*da0073e9SAndroid Build Coastguard Worker return self.sub.a * 20 + self.sub.nested.abc * 3 + self.hi 7130*da0073e9SAndroid Build Coastguard Worker 7131*da0073e9SAndroid Build Coastguard Worker self.checkModule(TestModule(), ()) 7132*da0073e9SAndroid Build Coastguard Worker 7133*da0073e9SAndroid Build Coastguard Worker def test_number_neg(self): 7134*da0073e9SAndroid Build Coastguard Worker # int -> int 7135*da0073e9SAndroid Build Coastguard Worker def func1(): 7136*da0073e9SAndroid Build Coastguard Worker return -8 7137*da0073e9SAndroid Build Coastguard Worker 7138*da0073e9SAndroid Build Coastguard Worker # float -> float 7139*da0073e9SAndroid Build Coastguard Worker def func2(): 7140*da0073e9SAndroid Build Coastguard Worker return -3.14 7141*da0073e9SAndroid Build Coastguard Worker 7142*da0073e9SAndroid Build Coastguard Worker self.checkScript(func1, (), optimize=True) 7143*da0073e9SAndroid Build Coastguard Worker self.checkScript(func2, (), optimize=True) 7144*da0073e9SAndroid Build Coastguard Worker 7145*da0073e9SAndroid Build Coastguard Worker def test_compare_two_bool_inputs(self): 7146*da0073e9SAndroid Build Coastguard Worker def compare_eq(a: bool, b: bool): 7147*da0073e9SAndroid Build Coastguard Worker return a == b 7148*da0073e9SAndroid Build Coastguard Worker 7149*da0073e9SAndroid Build Coastguard Worker def compare_ne(a: bool, b: bool): 
7150*da0073e9SAndroid Build Coastguard Worker return a != b 7151*da0073e9SAndroid Build Coastguard Worker 7152*da0073e9SAndroid Build Coastguard Worker scripted_fn_eq = torch.jit.script(compare_eq) 7153*da0073e9SAndroid Build Coastguard Worker scripted_fn_ne = torch.jit.script(compare_ne) 7154*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn_eq(True, False), compare_eq(True, False)) 7155*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn_eq(False, True), compare_eq(False, True)) 7156*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn_eq(True, True), compare_eq(True, True)) 7157*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn_eq(False, False), compare_eq(False, False)) 7158*da0073e9SAndroid Build Coastguard Worker 7159*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn_ne(True, False), compare_ne(True, False)) 7160*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn_ne(False, True), compare_ne(False, True)) 7161*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn_ne(True, True), compare_ne(True, True)) 7162*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_fn_ne(False, False), compare_ne(False, False)) 7163*da0073e9SAndroid Build Coastguard Worker 7164*da0073e9SAndroid Build Coastguard Worker 7165*da0073e9SAndroid Build Coastguard Worker def _test_tensor_number_math(self, device='cpu'): 7166*da0073e9SAndroid Build Coastguard Worker template = dedent(''' 7167*da0073e9SAndroid Build Coastguard Worker def func(t): 7168*da0073e9SAndroid Build Coastguard Worker return {lhs} {op} {rhs} 7169*da0073e9SAndroid Build Coastguard Worker ''') 7170*da0073e9SAndroid Build Coastguard Worker 7171*da0073e9SAndroid Build Coastguard Worker def test(op, tensor, const, swap_args, template=template): 7172*da0073e9SAndroid Build Coastguard Worker args = ('t', const) 7173*da0073e9SAndroid Build Coastguard Worker if swap_args: 
7174*da0073e9SAndroid Build Coastguard Worker args = (const, 't') 7175*da0073e9SAndroid Build Coastguard Worker 7176*da0073e9SAndroid Build Coastguard Worker code = template.format(lhs=args[0], rhs=args[1], op=op) 7177*da0073e9SAndroid Build Coastguard Worker scope = {} 7178*da0073e9SAndroid Build Coastguard Worker execWrapper(code, globals(), scope) 7179*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 7180*da0073e9SAndroid Build Coastguard Worker message = f'with code `{args[0]} {op} {args[1]}` and t={tensor}' 7181*da0073e9SAndroid Build Coastguard Worker res1 = cu.func(tensor) 7182*da0073e9SAndroid Build Coastguard Worker res2 = scope['func'](tensor) 7183*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1, res2, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2)) 7184*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res1.dtype, res2.dtype, msg=message + "\nres1=" + str(res1) + "\nres2=" + str(res2)) 7185*da0073e9SAndroid Build Coastguard Worker 7186*da0073e9SAndroid Build Coastguard Worker var_int = [2, -2] 7187*da0073e9SAndroid Build Coastguard Worker var_float = [1.4321, -1.2] 7188*da0073e9SAndroid Build Coastguard Worker 7189*da0073e9SAndroid Build Coastguard Worker ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '/'] 7190*da0073e9SAndroid Build Coastguard Worker 7191*da0073e9SAndroid Build Coastguard Worker float_tensor = torch.randn(5, 5, device=device) 7192*da0073e9SAndroid Build Coastguard Worker double_tensor = torch.randn(5, 5, dtype=torch.double, device=device) 7193*da0073e9SAndroid Build Coastguard Worker long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device) 7194*da0073e9SAndroid Build Coastguard Worker long_tensor[long_tensor == 0] = 2 7195*da0073e9SAndroid Build Coastguard Worker 7196*da0073e9SAndroid Build Coastguard Worker tensors = [float_tensor, double_tensor, long_tensor] 7197*da0073e9SAndroid Build Coastguard Worker consts = var_int + var_float 
7198*da0073e9SAndroid Build Coastguard Worker 7199*da0073e9SAndroid Build Coastguard Worker for op, tensor, const, swap_args in product(ops, tensors, consts, [True, False]): 7200*da0073e9SAndroid Build Coastguard Worker # FIXME: things like 2 / long_tensor are not implemented correctly 7201*da0073e9SAndroid Build Coastguard Worker # Look in torch/_tensor.py to see how pytorch implements it. 7202*da0073e9SAndroid Build Coastguard Worker if op == '/' and tensor.data_ptr() == long_tensor.data_ptr(): 7203*da0073e9SAndroid Build Coastguard Worker continue 7204*da0073e9SAndroid Build Coastguard Worker 7205*da0073e9SAndroid Build Coastguard Worker # % operator does not take: const % tensor 7206*da0073e9SAndroid Build Coastguard Worker if op == '%' and swap_args is True: 7207*da0073e9SAndroid Build Coastguard Worker continue 7208*da0073e9SAndroid Build Coastguard Worker 7209*da0073e9SAndroid Build Coastguard Worker test(op, tensor, const, swap_args) 7210*da0073e9SAndroid Build Coastguard Worker 7211*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a TorchDynamo suitable test") 7212*da0073e9SAndroid Build Coastguard Worker def test_tensor_number_math(self): 7213*da0073e9SAndroid Build Coastguard Worker self._test_tensor_number_math() 7214*da0073e9SAndroid Build Coastguard Worker 7215*da0073e9SAndroid Build Coastguard Worker def test_torch_tensor_bad_input(self): 7216*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "must be of ints, floats, " 7217*da0073e9SAndroid Build Coastguard Worker "or bools, got None"): 7218*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7219*da0073e9SAndroid Build Coastguard Worker def test(): 7220*da0073e9SAndroid Build Coastguard Worker return torch.tensor([None]) 7221*da0073e9SAndroid Build Coastguard Worker test() 7222*da0073e9SAndroid Build Coastguard Worker 7223*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Empty lists default to 
List\[Tensor\]"): 7224*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7225*da0073e9SAndroid Build Coastguard Worker def tmp(): 7226*da0073e9SAndroid Build Coastguard Worker return torch.tensor([]) 7227*da0073e9SAndroid Build Coastguard Worker tmp() 7228*da0073e9SAndroid Build Coastguard Worker 7229*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7230*da0073e9SAndroid Build Coastguard Worker def foo(): 7231*da0073e9SAndroid Build Coastguard Worker return torch.tensor([[2, 2], [1]]) 7232*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"): 7233*da0073e9SAndroid Build Coastguard Worker foo() 7234*da0073e9SAndroid Build Coastguard Worker 7235*da0073e9SAndroid Build Coastguard Worker @suppress_warnings 7236*da0073e9SAndroid Build Coastguard Worker def test_torch_tensor_as_tensor_empty_list(self): 7237*da0073e9SAndroid Build Coastguard Worker tensor_template = dedent(''' 7238*da0073e9SAndroid Build Coastguard Worker def func(): 7239*da0073e9SAndroid Build Coastguard Worker empty_list = torch.jit.annotate(List[int], []) 7240*da0073e9SAndroid Build Coastguard Worker ten1 = torch.{tensor_op}({input}) 7241*da0073e9SAndroid Build Coastguard Worker return ten1 7242*da0073e9SAndroid Build Coastguard Worker ''') 7243*da0073e9SAndroid Build Coastguard Worker ops = ['tensor', 'as_tensor'] 7244*da0073e9SAndroid Build Coastguard Worker inputs = ['empty_list', '[empty_list, empty_list]', '[[[empty_list]]]'] 7245*da0073e9SAndroid Build Coastguard Worker 7246*da0073e9SAndroid Build Coastguard Worker for op in ops: 7247*da0073e9SAndroid Build Coastguard Worker for inp in inputs: 7248*da0073e9SAndroid Build Coastguard Worker code = tensor_template.format(tensor_op=op, input=inp) 7249*da0073e9SAndroid Build Coastguard Worker scope = {} 7250*da0073e9SAndroid Build Coastguard Worker exec(code, globals(), scope) 7251*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 
7252*da0073e9SAndroid Build Coastguard Worker t1 = cu.func() 7253*da0073e9SAndroid Build Coastguard Worker t2 = scope['func']() 7254*da0073e9SAndroid Build Coastguard Worker if inp == 'empty_list': 7255*da0073e9SAndroid Build Coastguard Worker # torchscript returns int tensor, python returns float tensor 7256*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t1.dtype, t2.dtype) 7257*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, t2, exact_dtype=False) 7258*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.device, t2.device) 7259*da0073e9SAndroid Build Coastguard Worker 7260*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "Simple Executor doesn't have any shapes to propagate") 7261*da0073e9SAndroid Build Coastguard Worker def test_tensor_as_tensor_shape_prop(self): 7262*da0073e9SAndroid Build Coastguard Worker tensor_template = dedent(''' 7263*da0073e9SAndroid Build Coastguard Worker def func(): 7264*da0073e9SAndroid Build Coastguard Worker return torch.{tensor_op}({input}) 7265*da0073e9SAndroid Build Coastguard Worker ''') 7266*da0073e9SAndroid Build Coastguard Worker ops = ['tensor', 'as_tensor'] 7267*da0073e9SAndroid Build Coastguard Worker inputs = ['[1]', '[False]', '[2.5]', '0.5', '1', 'False', '[[1]]', 'torch.jit.annotate(List[List[int]], [])'] 7268*da0073e9SAndroid Build Coastguard Worker expected_shape = ["Long(*, device=cpu)", "Bool(*, device=cpu)", 7269*da0073e9SAndroid Build Coastguard Worker "Float(*, device=cpu)", "Float(device=cpu)", 7270*da0073e9SAndroid Build Coastguard Worker "Long(device=cpu)", "Bool(device=cpu)", "Long(*, *, device=cpu)"] 7271*da0073e9SAndroid Build Coastguard Worker 7272*da0073e9SAndroid Build Coastguard Worker for op in ops: 7273*da0073e9SAndroid Build Coastguard Worker for inp, expect in zip(inputs, expected_shape): 7274*da0073e9SAndroid Build Coastguard Worker code = tensor_template.format(tensor_op=op, input=inp) 7275*da0073e9SAndroid 
Build Coastguard Worker scope = {} 7276*da0073e9SAndroid Build Coastguard Worker exec(code, globals(), scope) 7277*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 7278*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_complete_shape_analysis(cu.func.graph, (), False) 7279*da0073e9SAndroid Build Coastguard Worker FileCheck().check(expect).check(f"aten::{op}").run(cu.func.graph) 7280*da0073e9SAndroid Build Coastguard Worker 7281*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7282*da0073e9SAndroid Build Coastguard Worker def test_dtype(inp_dtype: torch.dtype): 7283*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(1.0, dtype=torch.float, requires_grad=True) 7284*da0073e9SAndroid Build Coastguard Worker return a, torch.tensor(1.0, dtype=inp_dtype) 7285*da0073e9SAndroid Build Coastguard Worker 7286*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.PROFILING: 7287*da0073e9SAndroid Build Coastguard Worker g = test_dtype.graph_for(5, profile_and_replay=True) 7288*da0073e9SAndroid Build Coastguard Worker # both should have completed shapes 7289*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Tensor = aten::tensor").check("Float(device=cpu) = prim::BailOut") \ 7290*da0073e9SAndroid Build Coastguard Worker .check("Tensor = aten::tensor").check("Half(device=cpu) = prim::BailOut").run(g) 7291*da0073e9SAndroid Build Coastguard Worker else: 7292*da0073e9SAndroid Build Coastguard Worker g = test_dtype.graph_for(5) 7293*da0073e9SAndroid Build Coastguard Worker # first should have type set second should not 7294*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Float(requires_grad=1, device=cpu) = aten::tensor") \ 7295*da0073e9SAndroid Build Coastguard Worker .check("Tensor(requires_grad=0) = aten::tensor").run(g) 7296*da0073e9SAndroid Build Coastguard Worker 7297*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7298*da0073e9SAndroid Build Coastguard Worker def 
test_as_tensor_tensor_input(input): 7299*da0073e9SAndroid Build Coastguard Worker a = torch.as_tensor(input, dtype=input.dtype) 7300*da0073e9SAndroid Build Coastguard Worker return a, torch.as_tensor(input, dtype=torch.float) 7301*da0073e9SAndroid Build Coastguard Worker 7302*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.PROFILING: 7303*da0073e9SAndroid Build Coastguard Worker g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4), profile_and_replay=True) 7304*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut") \ 7305*da0073e9SAndroid Build Coastguard Worker .check("Tensor = aten::as_tensor").check("Float(3, 4) = prim::BailOut").run(g) 7306*da0073e9SAndroid Build Coastguard Worker else: 7307*da0073e9SAndroid Build Coastguard Worker g = test_as_tensor_tensor_input.graph_for(torch.ones(3, 4)) 7308*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Tensor = aten::as_tensor").check("Float(*, *, requires_grad=0, device=cpu) = aten::as_tensor").run(g) 7309*da0073e9SAndroid Build Coastguard Worker 7310*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "testing legacy behavior") 7311*da0073e9SAndroid Build Coastguard Worker def test_tensor_requires_grad(self): 7312*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7313*da0073e9SAndroid Build Coastguard Worker def test(b): 7314*da0073e9SAndroid Build Coastguard Worker # type: (bool) -> Tuple[Tensor, Tensor, Tensor] 7315*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(1., requires_grad=b) 7316*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(1., requires_grad=True) 7317*da0073e9SAndroid Build Coastguard Worker c = torch.tensor(1., requires_grad=False) 7318*da0073e9SAndroid Build Coastguard Worker return a, b, c 7319*da0073e9SAndroid Build Coastguard Worker 7320*da0073e9SAndroid Build Coastguard Worker g = test.graph_for(True) 
7321*da0073e9SAndroid Build Coastguard Worker out = next(g.outputs()) 7322*da0073e9SAndroid Build Coastguard Worker out_inp = list(out.node().inputs()) 7323*da0073e9SAndroid Build Coastguard Worker 7324*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out_inp[0].requires_grad()) 7325*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out_inp[1].requires_grad()) 7326*da0073e9SAndroid Build Coastguard Worker self.assertFalse(out_inp[2].requires_grad()) 7327*da0073e9SAndroid Build Coastguard Worker 7328*da0073e9SAndroid Build Coastguard Worker def test_grad_from_script(self): 7329*da0073e9SAndroid Build Coastguard Worker def test(): 7330*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(2.5, requires_grad=True) 7331*da0073e9SAndroid Build Coastguard Worker b = a * 2 7332*da0073e9SAndroid Build Coastguard Worker return a, b 7333*da0073e9SAndroid Build Coastguard Worker 7334*da0073e9SAndroid Build Coastguard Worker a, b = test() 7335*da0073e9SAndroid Build Coastguard Worker b.backward() 7336*da0073e9SAndroid Build Coastguard Worker 7337*da0073e9SAndroid Build Coastguard Worker a_script, b_script = torch.jit.script(test)() 7338*da0073e9SAndroid Build Coastguard Worker b_script.backward() 7339*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.grad, a_script.grad) 7340*da0073e9SAndroid Build Coastguard Worker 7341*da0073e9SAndroid Build Coastguard Worker def test_torch_tensor_as_tensor(self): 7342*da0073e9SAndroid Build Coastguard Worker tensor_template = dedent(''' 7343*da0073e9SAndroid Build Coastguard Worker def func(): 7344*da0073e9SAndroid Build Coastguard Worker li = {list_create} 7345*da0073e9SAndroid Build Coastguard Worker ten1 = torch.{tensor_op}(li {options}) 7346*da0073e9SAndroid Build Coastguard Worker return ten1 7347*da0073e9SAndroid Build Coastguard Worker ''') 7348*da0073e9SAndroid Build Coastguard Worker 7349*da0073e9SAndroid Build Coastguard Worker lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, 
False]", "[2, 2]", "(1, 1)", 7350*da0073e9SAndroid Build Coastguard Worker "torch.jit.annotate(List[List[int]], [])", 7351*da0073e9SAndroid Build Coastguard Worker "torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"] 7352*da0073e9SAndroid Build Coastguard Worker 7353*da0073e9SAndroid Build Coastguard Worker dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half", 7354*da0073e9SAndroid Build Coastguard Worker ", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short", 7355*da0073e9SAndroid Build Coastguard Worker ", dtype=torch.int", ", dtype=torch.long", ", dtype=torch.cfloat", 7356*da0073e9SAndroid Build Coastguard Worker ", dtype=torch.cdouble"] 7357*da0073e9SAndroid Build Coastguard Worker 7358*da0073e9SAndroid Build Coastguard Worker ops = ['tensor', 'as_tensor'] 7359*da0073e9SAndroid Build Coastguard Worker devices = ['', ", device='cpu'"] 7360*da0073e9SAndroid Build Coastguard Worker if RUN_CUDA: 7361*da0073e9SAndroid Build Coastguard Worker devices.append(", device='cuda'") 7362*da0073e9SAndroid Build Coastguard Worker 7363*da0073e9SAndroid Build Coastguard Worker option_pairs = [dtype + device for dtype in dtypes for device in devices] 7364*da0073e9SAndroid Build Coastguard Worker for op in ops: 7365*da0073e9SAndroid Build Coastguard Worker for li in lists: 7366*da0073e9SAndroid Build Coastguard Worker for option in option_pairs: 7367*da0073e9SAndroid Build Coastguard Worker # tensor from empty list is type float in python and annotated type in torchscript 7368*da0073e9SAndroid Build Coastguard Worker if "annotate" in li and "dtype" not in option: 7369*da0073e9SAndroid Build Coastguard Worker continue 7370*da0073e9SAndroid Build Coastguard Worker # Skip unsigned tensor initializaton for signed values on 3.10 7371*da0073e9SAndroid Build Coastguard Worker if sys.version_info[:2] >= (3, 10) and "torch.uint8" in option and "-" in li: 7372*da0073e9SAndroid Build Coastguard 
Worker continue 7373*da0073e9SAndroid Build Coastguard Worker code = tensor_template.format(list_create=li, tensor_op=op, options=option) 7374*da0073e9SAndroid Build Coastguard Worker scope = {} 7375*da0073e9SAndroid Build Coastguard Worker exec(code, globals(), scope) 7376*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 7377*da0073e9SAndroid Build Coastguard Worker t1 = cu.func() 7378*da0073e9SAndroid Build Coastguard Worker t2 = scope['func']() 7379*da0073e9SAndroid Build Coastguard Worker if t1.dtype == torch.float16: # equality NYI for half tensor 7380*da0073e9SAndroid Build Coastguard Worker self.assertTrue(str(t1) == str(t2)) 7381*da0073e9SAndroid Build Coastguard Worker else: 7382*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1, t2) 7383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.dtype, t2.dtype) 7384*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t1.device, t2.device) 7385*da0073e9SAndroid Build Coastguard Worker 7386*da0073e9SAndroid Build Coastguard Worker def test_as_tensor_tensor_input(input): 7387*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tuple[Tensor, Tensor, Tensor] 7388*da0073e9SAndroid Build Coastguard Worker return torch.as_tensor(input, dtype=torch.cfloat), torch.as_tensor(input, dtype=torch.float), \ 7389*da0073e9SAndroid Build Coastguard Worker torch.as_tensor(input, dtype=torch.int32) 7390*da0073e9SAndroid Build Coastguard Worker 7391*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(3, 4, dtype=torch.cfloat) 7392*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_as_tensor_tensor_input, (inp,)) 7393*da0073e9SAndroid Build Coastguard Worker 7394*da0073e9SAndroid Build Coastguard Worker def test_torch_tensor_dtype(self): 7395*da0073e9SAndroid Build Coastguard Worker def foo(s: float): 7396*da0073e9SAndroid Build Coastguard Worker return torch.tensor(s), torch.tensor([s, s]) 7397*da0073e9SAndroid Build Coastguard Worker 
7398*da0073e9SAndroid Build Coastguard Worker # need to clear function cache so we re run shape analysis 7399*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.double): 7400*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True) 7401*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 7402*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Double").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph()) 7403*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.float): 7404*da0073e9SAndroid Build Coastguard Worker del torch.jit._state._jit_caching_layer[foo] 7405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True) 7406*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 7407*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Float").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph()) 7408*da0073e9SAndroid Build Coastguard Worker with set_default_dtype(torch.half): 7409*da0073e9SAndroid Build Coastguard Worker del torch.jit._state._jit_caching_layer[foo] 7410*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.jit.script(foo)(1.), foo(1.), exact_dtype=True) 7411*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 7412*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Half").check_same("aten::tensor").run(torch.jit.last_executed_optimized_graph()) 7413*da0073e9SAndroid Build Coastguard Worker 7414*da0073e9SAndroid Build Coastguard Worker def test_shape_analysis_grad_property(self): 7415*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7416*da0073e9SAndroid Build Coastguard Worker def foo(x): 7417*da0073e9SAndroid Build Coastguard Worker return torch.sub(x, torch.tanh(x)) 7418*da0073e9SAndroid Build Coastguard Worker 
7419*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_complete_shape_analysis(foo.graph, (torch.tensor([0.39]),), False) 7420*da0073e9SAndroid Build Coastguard Worker 7421*da0073e9SAndroid Build Coastguard Worker # requires_grad property shouldn't be accidentally set by shape analysis 7422*da0073e9SAndroid Build Coastguard Worker self.assertTrue(foo.graph.findNode("aten::sub").output().requiresGrad() is None) 7423*da0073e9SAndroid Build Coastguard Worker 7424*da0073e9SAndroid Build Coastguard Worker def test_empty_like_memory_format_bc(self): 7425*da0073e9SAndroid Build Coastguard Worker def f(x): 7426*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 7427*da0073e9SAndroid Build Coastguard Worker return torch.zeros_like(x, memory_format=None) 7428*da0073e9SAndroid Build Coastguard Worker 7429*da0073e9SAndroid Build Coastguard Worker scripted_f = torch.jit.script(f) 7430*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 7431*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_f(x), f(x)) 7432*da0073e9SAndroid Build Coastguard Worker 7433*da0073e9SAndroid Build Coastguard Worker def test_multiline_string_dedents(self): 7434*da0073e9SAndroid Build Coastguard Worker def foo() -> None: 7435*da0073e9SAndroid Build Coastguard Worker multiline_string_dedent_1 = """ 7436*da0073e9SAndroid Build Coastguard WorkerThis is a string dedent """ 7437*da0073e9SAndroid Build Coastguard Worker multiline_string_dedent_2 = """ This is a 7438*da0073e9SAndroid Build Coastguard Worker string dedent """ 7439*da0073e9SAndroid Build Coastguard Worker multiline_string_dedent_3 = """ 7440*da0073e9SAndroid Build Coastguard Worker This is a string 7441*da0073e9SAndroid Build Coastguard Workerdedent """ 7442*da0073e9SAndroid Build Coastguard Worker multiline_string_dedent_4 = """ This is a string dedent """ 7443*da0073e9SAndroid Build Coastguard Worker 7444*da0073e9SAndroid Build Coastguard Worker scripted_foo = torch.jit.script(foo) 
7445*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted_foo(), foo()) 7446*da0073e9SAndroid Build Coastguard Worker 7447*da0073e9SAndroid Build Coastguard Worker def test_class_with_comment_at_lower_indentation(self): 7448*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 7449*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 7450*da0073e9SAndroid Build Coastguard Worker x = torch.neg(x) 7451*da0073e9SAndroid Build Coastguard Worker # This comment is at the wrong indent 7452*da0073e9SAndroid Build Coastguard Worker return x 7453*da0073e9SAndroid Build Coastguard Worker 7454*da0073e9SAndroid Build Coastguard Worker torch.jit.script(Foo()) 7455*da0073e9SAndroid Build Coastguard Worker 7456*da0073e9SAndroid Build Coastguard Worker # adapted from test in test_torch 7457*da0073e9SAndroid Build Coastguard Worker def test_tensor_to(self): 7458*da0073e9SAndroid Build Coastguard Worker template = dedent(''' 7459*da0073e9SAndroid Build Coastguard Worker def func(t): 7460*da0073e9SAndroid Build Coastguard Worker cuda = "{cuda}" 7461*da0073e9SAndroid Build Coastguard Worker device = "{device}" 7462*da0073e9SAndroid Build Coastguard Worker non_blocking = {non_blocking} 7463*da0073e9SAndroid Build Coastguard Worker return {to_str} 7464*da0073e9SAndroid Build Coastguard Worker ''') 7465*da0073e9SAndroid Build Coastguard Worker 7466*da0073e9SAndroid Build Coastguard Worker def s(t, to_str, non_blocking=None, device=None, cuda=None): 7467*da0073e9SAndroid Build Coastguard Worker device = device if device is not None else str(t.device) 7468*da0073e9SAndroid Build Coastguard Worker non_blocking = non_blocking if non_blocking is not None else False 7469*da0073e9SAndroid Build Coastguard Worker cuda = "cuda" if cuda is None else cuda 7470*da0073e9SAndroid Build Coastguard Worker code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda) 7471*da0073e9SAndroid Build Coastguard Worker scope = {} 
7472*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 7473*da0073e9SAndroid Build Coastguard Worker return cu.func(t, profile_and_replay=True) 7474*da0073e9SAndroid Build Coastguard Worker 7475*da0073e9SAndroid Build Coastguard Worker def test_copy_behavior(t, non_blocking=False): 7476*da0073e9SAndroid Build Coastguard Worker self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking)) 7477*da0073e9SAndroid Build Coastguard Worker self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking)) 7478*da0073e9SAndroid Build Coastguard Worker self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking)) 7479*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking)) 7480*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking)) 7481*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking)) 7482*da0073e9SAndroid Build Coastguard Worker 7483*da0073e9SAndroid Build Coastguard Worker devices = [t.device] 7484*da0073e9SAndroid Build Coastguard Worker if t.device.type == 'cuda': 7485*da0073e9SAndroid Build Coastguard Worker if t.device.index == -1: 7486*da0073e9SAndroid Build Coastguard Worker devices.append(f'cuda:{torch.cuda.current_device()}') 7487*da0073e9SAndroid Build Coastguard Worker elif t.device.index == torch.cuda.current_device(): 7488*da0073e9SAndroid Build Coastguard Worker devices.append('cuda') 7489*da0073e9SAndroid Build Coastguard Worker for device in devices: 7490*da0073e9SAndroid Build Coastguard Worker self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device)) 7491*da0073e9SAndroid Build Coastguard Worker self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device)) 
7492*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device)) 7493*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)', 7494*da0073e9SAndroid Build Coastguard Worker non_blocking, device)) 7495*da0073e9SAndroid Build Coastguard Worker 7496*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(5) 7497*da0073e9SAndroid Build Coastguard Worker test_copy_behavior(t) 7498*da0073e9SAndroid Build Coastguard Worker 7499*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.device, s(t, "t.to('cpu')").device) 7500*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.device, s(t, "t.to('cpu', dtype=torch.float32)").device) 7501*da0073e9SAndroid Build Coastguard Worker self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype) 7502*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.device, s(t, "t.to(torch.float32)").device) 7503*da0073e9SAndroid Build Coastguard Worker self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype) 7504*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr()) 7505*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr()) 7506*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr()) 7507*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr()) 7508*da0073e9SAndroid Build Coastguard Worker 7509*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(5) 7510*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available(): 7511*da0073e9SAndroid Build Coastguard Worker for non_blocking in [True, False]: 7512*da0073e9SAndroid Build Coastguard Worker for cuda in ['cuda', 'cuda:0' if 
torch.cuda.device_count() == 1 else 'cuda:1']: 7513*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(5., device=cuda) 7514*da0073e9SAndroid Build Coastguard Worker test_copy_behavior(b, non_blocking) 7515*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda)) 7516*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device")) 7517*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda)) 7518*da0073e9SAndroid Build Coastguard Worker self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype) 7519*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device) 7520*da0073e9SAndroid Build Coastguard Worker self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype) 7521*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device) 7522*da0073e9SAndroid Build Coastguard Worker 7523*da0073e9SAndroid Build Coastguard Worker # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor 7524*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(5).float().requires_grad_() 7525*da0073e9SAndroid Build Coastguard Worker out_ref = t.to(torch.float32) 7526*da0073e9SAndroid Build Coastguard Worker out = s(t, "t.to(torch.float32)") 7527*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out) 7528*da0073e9SAndroid Build Coastguard Worker 7529*da0073e9SAndroid Build Coastguard Worker grad_ref = torch.autograd.grad(out_ref.sum(), t) 7530*da0073e9SAndroid Build Coastguard Worker grad = torch.autograd.grad(out.sum(), t) 7531*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_ref, grad) 7532*da0073e9SAndroid Build Coastguard Worker 
7533*da0073e9SAndroid Build Coastguard Worker # Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor 7534*da0073e9SAndroid Build Coastguard Worker out_ref = t.to('cpu') 7535*da0073e9SAndroid Build Coastguard Worker out = s(t, "t.to('cpu')") 7536*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_ref, out) 7537*da0073e9SAndroid Build Coastguard Worker 7538*da0073e9SAndroid Build Coastguard Worker grad_ref = torch.autograd.grad(out_ref.sum(), t) 7539*da0073e9SAndroid Build Coastguard Worker grad = torch.autograd.grad(out.sum(), t) 7540*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_ref, grad) 7541*da0073e9SAndroid Build Coastguard Worker 7542*da0073e9SAndroid Build Coastguard Worker # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor 7543*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7544*da0073e9SAndroid Build Coastguard Worker def func2(t, t_ref): 7545*da0073e9SAndroid Build Coastguard Worker return t.to(t_ref) 7546*da0073e9SAndroid Build Coastguard Worker 7547*da0073e9SAndroid Build Coastguard Worker with disable_autodiff_subgraph_inlining(): 7548*da0073e9SAndroid Build Coastguard Worker t_ref = torch.tensor(4).double() 7549*da0073e9SAndroid Build Coastguard Worker out_ref = t.to(t_ref) 7550*da0073e9SAndroid Build Coastguard Worker out = func2(t, t_ref) 7551*da0073e9SAndroid Build Coastguard Worker grad_ref = torch.autograd.grad(out_ref.sum(), t) 7552*da0073e9SAndroid Build Coastguard Worker grad = torch.autograd.grad(out.sum(), t) 7553*da0073e9SAndroid Build Coastguard Worker self.assertEqual(grad_ref, grad) 7554*da0073e9SAndroid Build Coastguard Worker 7555*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "No CUDA") 7556*da0073e9SAndroid Build Coastguard Worker def test_tensor_number_math_cuda(self): 7557*da0073e9SAndroid Build Coastguard Worker self._test_tensor_number_math(device='cuda') 7558*da0073e9SAndroid Build 
Coastguard Worker 7559*da0073e9SAndroid Build Coastguard Worker def test_not(self): 7560*da0073e9SAndroid Build Coastguard Worker # test not operator in python 7561*da0073e9SAndroid Build Coastguard Worker # TODO: add more tests when bool conversions ready 7562*da0073e9SAndroid Build Coastguard Worker def test_not_op(a): 7563*da0073e9SAndroid Build Coastguard Worker return not bool(a > 1) 7564*da0073e9SAndroid Build Coastguard Worker 7565*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_not_op, (torch.tensor(2), ), optimize=True) 7566*da0073e9SAndroid Build Coastguard Worker 7567*da0073e9SAndroid Build Coastguard Worker def test_is_isnot(self): 7568*da0073e9SAndroid Build Coastguard Worker # test is and is not operator in python 7569*da0073e9SAndroid Build Coastguard Worker template = dedent(''' 7570*da0073e9SAndroid Build Coastguard Worker def func(): 7571*da0073e9SAndroid Build Coastguard Worker # type: () -> bool 7572*da0073e9SAndroid Build Coastguard Worker return {lhs} {op} {rhs} 7573*da0073e9SAndroid Build Coastguard Worker ''') 7574*da0073e9SAndroid Build Coastguard Worker 7575*da0073e9SAndroid Build Coastguard Worker def test(op, args): 7576*da0073e9SAndroid Build Coastguard Worker code = template.format(lhs=args[0], rhs=args[1], op=op) 7577*da0073e9SAndroid Build Coastguard Worker scope = {} 7578*da0073e9SAndroid Build Coastguard Worker execWrapper(code, globals(), scope) 7579*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 7580*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7581*da0073e9SAndroid Build Coastguard Worker cu.func(), 7582*da0073e9SAndroid Build Coastguard Worker scope['func'](), 7583*da0073e9SAndroid Build Coastguard Worker msg=f"Failed with op: {op}, lhs: {args[0]}, rhs: {args[1]}" 7584*da0073e9SAndroid Build Coastguard Worker ) 7585*da0073e9SAndroid Build Coastguard Worker 7586*da0073e9SAndroid Build Coastguard Worker ops = ['is', 'is not'] 7587*da0073e9SAndroid Build Coastguard Worker 
type_literals = [True, False, None, [1, 1], 1, 2, .5, 1.5] 7588*da0073e9SAndroid Build Coastguard Worker 7589*da0073e9SAndroid Build Coastguard Worker # do literals product to try any types combinations 7590*da0073e9SAndroid Build Coastguard Worker for op, lhs, rhs in product(ops, type_literals, type_literals): 7591*da0073e9SAndroid Build Coastguard Worker test(op, [lhs, rhs]) 7592*da0073e9SAndroid Build Coastguard Worker 7593*da0073e9SAndroid Build Coastguard Worker def test_isinstance_refinement(self): 7594*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7595*da0073e9SAndroid Build Coastguard Worker def foo(a): 7596*da0073e9SAndroid Build Coastguard Worker # type: (Optional[int]) -> int 7597*da0073e9SAndroid Build Coastguard Worker if isinstance(a, int): 7598*da0073e9SAndroid Build Coastguard Worker return a + 3 7599*da0073e9SAndroid Build Coastguard Worker else: 7600*da0073e9SAndroid Build Coastguard Worker return 4 7601*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(4), 7) 7602*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(None), 4) 7603*da0073e9SAndroid Build Coastguard Worker 7604*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7605*da0073e9SAndroid Build Coastguard Worker def foo2(a, b): 7606*da0073e9SAndroid Build Coastguard Worker # type: (Optional[int], Optional[int]) -> int 7607*da0073e9SAndroid Build Coastguard Worker if not isinstance(a, int) or not isinstance(b, int): 7608*da0073e9SAndroid Build Coastguard Worker return 0 7609*da0073e9SAndroid Build Coastguard Worker else: 7610*da0073e9SAndroid Build Coastguard Worker return a + b 7611*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo2(3, 4), 7) 7612*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo2(None, 4), 0) 7613*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo2(4, None), 0) 7614*da0073e9SAndroid Build Coastguard Worker 7615*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7616*da0073e9SAndroid 
Build Coastguard Worker def any_refinement(a, b): 7617*da0073e9SAndroid Build Coastguard Worker # type: (Any, Any) -> int 7618*da0073e9SAndroid Build Coastguard Worker if isinstance(a, int) and isinstance(b, int): 7619*da0073e9SAndroid Build Coastguard Worker return a + b 7620*da0073e9SAndroid Build Coastguard Worker return 0 7621*da0073e9SAndroid Build Coastguard Worker 7622*da0073e9SAndroid Build Coastguard Worker self.assertEqual(any_refinement(3, 4), 7) 7623*da0073e9SAndroid Build Coastguard Worker self.assertEqual(any_refinement(3, "hi"), 0) 7624*da0073e9SAndroid Build Coastguard Worker 7625*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7626*da0073e9SAndroid Build Coastguard Worker def any_refinement2(a): 7627*da0073e9SAndroid Build Coastguard Worker # type: (Any) -> Tensor 7628*da0073e9SAndroid Build Coastguard Worker if isinstance(a, Tensor): 7629*da0073e9SAndroid Build Coastguard Worker return a 7630*da0073e9SAndroid Build Coastguard Worker return torch.tensor(3) 7631*da0073e9SAndroid Build Coastguard Worker 7632*da0073e9SAndroid Build Coastguard Worker self.assertEqual(any_refinement2(3), torch.tensor(3)) 7633*da0073e9SAndroid Build Coastguard Worker self.assertEqual(any_refinement2(torch.tensor(5)), torch.tensor(5)) 7634*da0073e9SAndroid Build Coastguard Worker 7635*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "bug persists in deprecated executor") 7636*da0073e9SAndroid Build Coastguard Worker def test_unspecialized_any_binding(self): 7637*da0073e9SAndroid Build Coastguard Worker # any binding will infer the type, if it infers 7638*da0073e9SAndroid Build Coastguard Worker # a specialized tensor type `x` Dict type will fail isinstance check 7639*da0073e9SAndroid Build Coastguard Worker 7640*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7641*da0073e9SAndroid Build Coastguard Worker def foo(x: Any): 7642*da0073e9SAndroid Build Coastguard Worker assert isinstance(x, Dict[str, 
torch.Tensor]) 7643*da0073e9SAndroid Build Coastguard Worker 7644*da0073e9SAndroid Build Coastguard Worker foo({"1": torch.tensor(3)}) 7645*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(Exception): 7646*da0073e9SAndroid Build Coastguard Worker foo(2) 7647*da0073e9SAndroid Build Coastguard Worker 7648*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a TorchDynamo suitable test") 7649*da0073e9SAndroid Build Coastguard Worker def test_isinstance(self): 7650*da0073e9SAndroid Build Coastguard Worker # test isinstance operator for static type checking 7651*da0073e9SAndroid Build Coastguard Worker template = dedent(''' 7652*da0073e9SAndroid Build Coastguard Worker def func(x): 7653*da0073e9SAndroid Build Coastguard Worker # type: ({type_hint}) -> bool 7654*da0073e9SAndroid Build Coastguard Worker return isinstance(x, {typ}) 7655*da0073e9SAndroid Build Coastguard Worker ''') 7656*da0073e9SAndroid Build Coastguard Worker 7657*da0073e9SAndroid Build Coastguard Worker def test(inp, typ, type_hint): 7658*da0073e9SAndroid Build Coastguard Worker code = template.format(typ=typ, type_hint=type_hint) 7659*da0073e9SAndroid Build Coastguard Worker scope = {} 7660*da0073e9SAndroid Build Coastguard Worker execWrapper(code, globals(), scope) 7661*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 7662*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7663*da0073e9SAndroid Build Coastguard Worker cu.func(inp), 7664*da0073e9SAndroid Build Coastguard Worker scope['func'](inp), 7665*da0073e9SAndroid Build Coastguard Worker msg=f"Failed with typ: {typ}" 7666*da0073e9SAndroid Build Coastguard Worker ) 7667*da0073e9SAndroid Build Coastguard Worker 7668*da0073e9SAndroid Build Coastguard Worker inputs = [True, 1, 1.0, torch.tensor(1), [1, 2], (1.0,), [1, 2], 1] 7669*da0073e9SAndroid Build Coastguard Worker type_literals = ['bool', 'int', 'float', 'torch.Tensor', 'list', 'tuple', 7670*da0073e9SAndroid Build Coastguard Worker 
'(list, tuple)', '(int, float, bool)'] 7671*da0073e9SAndroid Build Coastguard Worker type_annotations = ['bool', 'int', 'float', 'Tensor', 'List[int]', 'Tuple[float]', 7672*da0073e9SAndroid Build Coastguard Worker 'List[int]', 'int'] 7673*da0073e9SAndroid Build Coastguard Worker 7674*da0073e9SAndroid Build Coastguard Worker # do zipping to try different types 7675*da0073e9SAndroid Build Coastguard Worker for inp, typ, type_hint in zip(inputs, type_literals, type_annotations): 7676*da0073e9SAndroid Build Coastguard Worker test(inp, typ, type_hint) 7677*da0073e9SAndroid Build Coastguard Worker 7678*da0073e9SAndroid Build Coastguard Worker # test optional isinstance check 7679*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 7680*da0073e9SAndroid Build Coastguard Worker def opt_func(x): 7681*da0073e9SAndroid Build Coastguard Worker # type: (Optional[int]) -> bool 7682*da0073e9SAndroid Build Coastguard Worker return isinstance(x, int) 7683*da0073e9SAndroid Build Coastguard Worker self.assertTrue(opt_func(3)) 7684*da0073e9SAndroid Build Coastguard Worker self.assertFalse(opt_func(None)) 7685*da0073e9SAndroid Build Coastguard Worker 7686*da0073e9SAndroid Build Coastguard Worker def test_dropout_eval(self): 7687*da0073e9SAndroid Build Coastguard Worker class ScriptedConv2d(torch.jit.ScriptModule): 7688*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, **kwargs): 7689*da0073e9SAndroid Build Coastguard Worker super().__init__() 7690*da0073e9SAndroid Build Coastguard Worker self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 7691*da0073e9SAndroid Build Coastguard Worker self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 7692*da0073e9SAndroid Build Coastguard Worker 7693*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 7694*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 7695*da0073e9SAndroid Build Coastguard Worker x = self.conv(x) 7696*da0073e9SAndroid Build Coastguard Worker x 
= self.bn(x) 7697*da0073e9SAndroid Build Coastguard Worker return F.relu(x, inplace=True) 7698*da0073e9SAndroid Build Coastguard Worker 7699*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 7700*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 7701*da0073e9SAndroid Build Coastguard Worker super().__init__() 7702*da0073e9SAndroid Build Coastguard Worker self.Conv2d_1a_3x3 = ScriptedConv2d(3, 32, kernel_size=3, stride=2) 7703*da0073e9SAndroid Build Coastguard Worker 7704*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 7705*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 7706*da0073e9SAndroid Build Coastguard Worker x = self.Conv2d_1a_3x3(x) 7707*da0073e9SAndroid Build Coastguard Worker return F.dropout(x, training=self.training) 7708*da0073e9SAndroid Build Coastguard Worker 7709*da0073e9SAndroid Build Coastguard Worker class EagerConv2d(torch.nn.Module): 7710*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_channels, out_channels, **kwargs): 7711*da0073e9SAndroid Build Coastguard Worker super().__init__() 7712*da0073e9SAndroid Build Coastguard Worker self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 7713*da0073e9SAndroid Build Coastguard Worker self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 7714*da0073e9SAndroid Build Coastguard Worker 7715*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 7716*da0073e9SAndroid Build Coastguard Worker x = self.conv(x) 7717*da0073e9SAndroid Build Coastguard Worker x = self.bn(x) 7718*da0073e9SAndroid Build Coastguard Worker return F.relu(x, inplace=True) 7719*da0073e9SAndroid Build Coastguard Worker 7720*da0073e9SAndroid Build Coastguard Worker class EagerMod(torch.nn.Module): 7721*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 7722*da0073e9SAndroid Build Coastguard Worker super().__init__() 7723*da0073e9SAndroid Build Coastguard Worker self.Conv2d_1a_3x3 = EagerConv2d(3, 32, 
kernel_size=3, stride=2) 7724*da0073e9SAndroid Build Coastguard Worker 7725*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 7726*da0073e9SAndroid Build Coastguard Worker x = self.Conv2d_1a_3x3(x) 7727*da0073e9SAndroid Build Coastguard Worker return F.dropout(x, training=self.training) 7728*da0073e9SAndroid Build Coastguard Worker 7729*da0073e9SAndroid Build Coastguard Worker script_input = torch.rand(4, 3, 299, 299) 7730*da0073e9SAndroid Build Coastguard Worker eager_input = script_input.clone() 7731*da0073e9SAndroid Build Coastguard Worker 7732*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 7733*da0073e9SAndroid Build Coastguard Worker script_mod = ScriptMod() 7734*da0073e9SAndroid Build Coastguard Worker script_mod.eval() 7735*da0073e9SAndroid Build Coastguard Worker script_output = script_mod(script_input) 7736*da0073e9SAndroid Build Coastguard Worker 7737*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 7738*da0073e9SAndroid Build Coastguard Worker eager_mod = EagerMod() 7739*da0073e9SAndroid Build Coastguard Worker eager_mod.eval() 7740*da0073e9SAndroid Build Coastguard Worker eager_output = eager_mod(eager_input) 7741*da0073e9SAndroid Build Coastguard Worker 7742*da0073e9SAndroid Build Coastguard Worker self.assertEqual(script_output, eager_output) 7743*da0073e9SAndroid Build Coastguard Worker 7744*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 7745*da0073e9SAndroid Build Coastguard Worker script_mod = ScriptMod() 7746*da0073e9SAndroid Build Coastguard Worker script_mod.train() 7747*da0073e9SAndroid Build Coastguard Worker script_output = script_mod(script_input) 7748*da0073e9SAndroid Build Coastguard Worker 7749*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 7750*da0073e9SAndroid Build Coastguard Worker eager_mod = EagerMod() 7751*da0073e9SAndroid Build Coastguard Worker eager_mod.train() 7752*da0073e9SAndroid Build Coastguard Worker eager_output = eager_mod(eager_input) 
7753*da0073e9SAndroid Build Coastguard Worker 7754*da0073e9SAndroid Build Coastguard Worker self.assertEqual(script_output, eager_output) 7755*da0073e9SAndroid Build Coastguard Worker 7756*da0073e9SAndroid Build Coastguard Worker def test_nested_breaks(self): 7757*da0073e9SAndroid Build Coastguard Worker def no_bool_loop_outputs(g): 7758*da0073e9SAndroid Build Coastguard Worker # testing that the "did exit" transform values are not loop block 7759*da0073e9SAndroid Build Coastguard Worker # outputs (and thus not affecting one loop from another) 7760*da0073e9SAndroid Build Coastguard Worker loops = g.findAllNodes("prim::Loop") 7761*da0073e9SAndroid Build Coastguard Worker for loop in loops: 7762*da0073e9SAndroid Build Coastguard Worker for out in loop.outputs(): 7763*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out.type() != BoolType.get()) 7764*da0073e9SAndroid Build Coastguard Worker 7765*da0073e9SAndroid Build Coastguard Worker def test(y): 7766*da0073e9SAndroid Build Coastguard Worker # type: (int) 7767*da0073e9SAndroid Build Coastguard Worker ret = 0 7768*da0073e9SAndroid Build Coastguard Worker tensor = torch.tensor(0) 7769*da0073e9SAndroid Build Coastguard Worker while int(tensor.add_(1)) < 4: 7770*da0073e9SAndroid Build Coastguard Worker if y == 1: 7771*da0073e9SAndroid Build Coastguard Worker continue 7772*da0073e9SAndroid Build Coastguard Worker for i in range(y): 7773*da0073e9SAndroid Build Coastguard Worker continue 7774*da0073e9SAndroid Build Coastguard Worker ret += 1 7775*da0073e9SAndroid Build Coastguard Worker ret += 1 7776*da0073e9SAndroid Build Coastguard Worker return ret, int(tensor) 7777*da0073e9SAndroid Build Coastguard Worker 7778*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.jit.script(test)(1), test(1)) 7779*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.jit.script(test)(2), test(2)) 7780*da0073e9SAndroid Build Coastguard Worker no_bool_loop_outputs(torch.jit.script(test).graph) 
7781*da0073e9SAndroid Build Coastguard Worker 7782*da0073e9SAndroid Build Coastguard Worker def foo(): 7783*da0073e9SAndroid Build Coastguard Worker y = torch.tensor(0) 7784*da0073e9SAndroid Build Coastguard Worker z = 0 7785*da0073e9SAndroid Build Coastguard Worker while int(y.add_(1)) < 20: 7786*da0073e9SAndroid Build Coastguard Worker if int(y) < 10: 7787*da0073e9SAndroid Build Coastguard Worker for i in range(6): 7788*da0073e9SAndroid Build Coastguard Worker if i == 3: 7789*da0073e9SAndroid Build Coastguard Worker continue 7790*da0073e9SAndroid Build Coastguard Worker else: 7791*da0073e9SAndroid Build Coastguard Worker if i > 3: 7792*da0073e9SAndroid Build Coastguard Worker break 7793*da0073e9SAndroid Build Coastguard Worker z += 2 7794*da0073e9SAndroid Build Coastguard Worker if int(y) == 18: 7795*da0073e9SAndroid Build Coastguard Worker break 7796*da0073e9SAndroid Build Coastguard Worker if int(y) == 15: 7797*da0073e9SAndroid Build Coastguard Worker continue 7798*da0073e9SAndroid Build Coastguard Worker z += 1 7799*da0073e9SAndroid Build Coastguard Worker return int(y), z 7800*da0073e9SAndroid Build Coastguard Worker 7801*da0073e9SAndroid Build Coastguard Worker no_bool_loop_outputs(torch.jit.script(foo).graph) 7802*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, ()) 7803*da0073e9SAndroid Build Coastguard Worker 7804*da0073e9SAndroid Build Coastguard Worker def test_nested_two(): 7805*da0073e9SAndroid Build Coastguard Worker i = 0 7806*da0073e9SAndroid Build Coastguard Worker k = 0 7807*da0073e9SAndroid Build Coastguard Worker while i < 5: 7808*da0073e9SAndroid Build Coastguard Worker for j in range(5): 7809*da0073e9SAndroid Build Coastguard Worker k += 1 7810*da0073e9SAndroid Build Coastguard Worker if j == 3: 7811*da0073e9SAndroid Build Coastguard Worker continue 7812*da0073e9SAndroid Build Coastguard Worker i += 1 7813*da0073e9SAndroid Build Coastguard Worker k += 1 7814*da0073e9SAndroid Build Coastguard Worker if i == 4: 
7815*da0073e9SAndroid Build Coastguard Worker break 7816*da0073e9SAndroid Build Coastguard Worker return i, k 7817*da0073e9SAndroid Build Coastguard Worker 7818*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_nested_two, ()) 7819*da0073e9SAndroid Build Coastguard Worker no_bool_loop_outputs(torch.jit.script(test_nested_two).graph) 7820*da0073e9SAndroid Build Coastguard Worker 7821*da0073e9SAndroid Build Coastguard Worker def test_breaks_continues(self): 7822*da0073e9SAndroid Build Coastguard Worker def foo_continue(cond): 7823*da0073e9SAndroid Build Coastguard Worker # type: (int) 7824*da0073e9SAndroid Build Coastguard Worker j = 1 7825*da0073e9SAndroid Build Coastguard Worker for i in range(5): 7826*da0073e9SAndroid Build Coastguard Worker if i == cond: 7827*da0073e9SAndroid Build Coastguard Worker continue 7828*da0073e9SAndroid Build Coastguard Worker j += 1 7829*da0073e9SAndroid Build Coastguard Worker return j 7830*da0073e9SAndroid Build Coastguard Worker 7831*da0073e9SAndroid Build Coastguard Worker def foo_break(cond): 7832*da0073e9SAndroid Build Coastguard Worker # type: (int) 7833*da0073e9SAndroid Build Coastguard Worker j = 1 7834*da0073e9SAndroid Build Coastguard Worker for i in range(5): 7835*da0073e9SAndroid Build Coastguard Worker if i == cond: 7836*da0073e9SAndroid Build Coastguard Worker break 7837*da0073e9SAndroid Build Coastguard Worker j += 1 7838*da0073e9SAndroid Build Coastguard Worker return j 7839*da0073e9SAndroid Build Coastguard Worker 7840*da0073e9SAndroid Build Coastguard Worker for i in range(1, 4): 7841*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo_continue, (i,)) 7842*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo_break, (i,)) 7843*da0073e9SAndroid Build Coastguard Worker 7844*da0073e9SAndroid Build Coastguard Worker def test_refine_outside_loop(): 7845*da0073e9SAndroid Build Coastguard Worker if 1 == 1: 7846*da0073e9SAndroid Build Coastguard Worker x = None 7847*da0073e9SAndroid Build 
Coastguard Worker else: 7848*da0073e9SAndroid Build Coastguard Worker x = 1 7849*da0073e9SAndroid Build Coastguard Worker i = 0 7850*da0073e9SAndroid Build Coastguard Worker j = 0 7851*da0073e9SAndroid Build Coastguard Worker while (x is None or torch.jit._unwrap_optional(x) > 3): 7852*da0073e9SAndroid Build Coastguard Worker if i < 3: 7853*da0073e9SAndroid Build Coastguard Worker if i < 3: 7854*da0073e9SAndroid Build Coastguard Worker x = torch.jit.annotate(Optional[int], None) 7855*da0073e9SAndroid Build Coastguard Worker i += 1 7856*da0073e9SAndroid Build Coastguard Worker continue 7857*da0073e9SAndroid Build Coastguard Worker x = 1 7858*da0073e9SAndroid Build Coastguard Worker else: 7859*da0073e9SAndroid Build Coastguard Worker x = 1 if x is None else x 7860*da0073e9SAndroid Build Coastguard Worker x = x + 1 7861*da0073e9SAndroid Build Coastguard Worker j = x + x 7862*da0073e9SAndroid Build Coastguard Worker 7863*da0073e9SAndroid Build Coastguard Worker return x, j 7864*da0073e9SAndroid Build Coastguard Worker 7865*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_refine_outside_loop, ()) 7866*da0073e9SAndroid Build Coastguard Worker 7867*da0073e9SAndroid Build Coastguard Worker def assign_after_break(y): 7868*da0073e9SAndroid Build Coastguard Worker # type: (int) 7869*da0073e9SAndroid Build Coastguard Worker x = 0 7870*da0073e9SAndroid Build Coastguard Worker for i in range(y): 7871*da0073e9SAndroid Build Coastguard Worker x = y * 2 + i 7872*da0073e9SAndroid Build Coastguard Worker break 7873*da0073e9SAndroid Build Coastguard Worker x = 4 7874*da0073e9SAndroid Build Coastguard Worker return x 7875*da0073e9SAndroid Build Coastguard Worker 7876*da0073e9SAndroid Build Coastguard Worker self.checkScript(assign_after_break, (1,)) 7877*da0073e9SAndroid Build Coastguard Worker self.checkScript(assign_after_break, (2,)) 7878*da0073e9SAndroid Build Coastguard Worker self.checkScript(assign_after_break, (3,)) 7879*da0073e9SAndroid Build Coastguard Worker 
7880*da0073e9SAndroid Build Coastguard Worker def assign_after_break_nested(y): 7881*da0073e9SAndroid Build Coastguard Worker # type: (int) 7882*da0073e9SAndroid Build Coastguard Worker x = 0 7883*da0073e9SAndroid Build Coastguard Worker for i in range(y): 7884*da0073e9SAndroid Build Coastguard Worker if y == 1: 7885*da0073e9SAndroid Build Coastguard Worker x = 5 7886*da0073e9SAndroid Build Coastguard Worker break 7887*da0073e9SAndroid Build Coastguard Worker assert 1 == 2 7888*da0073e9SAndroid Build Coastguard Worker else: 7889*da0073e9SAndroid Build Coastguard Worker x = x + 1 7890*da0073e9SAndroid Build Coastguard Worker break 7891*da0073e9SAndroid Build Coastguard Worker assert 1 == 2 7892*da0073e9SAndroid Build Coastguard Worker x = -30 7893*da0073e9SAndroid Build Coastguard Worker assert 1 == 2 7894*da0073e9SAndroid Build Coastguard Worker return x 7895*da0073e9SAndroid Build Coastguard Worker 7896*da0073e9SAndroid Build Coastguard Worker self.checkScript(assign_after_break_nested, (1,)) 7897*da0073e9SAndroid Build Coastguard Worker self.checkScript(assign_after_break_nested, (2,)) 7898*da0073e9SAndroid Build Coastguard Worker self.checkScript(assign_after_break_nested, (3,)) 7899*da0073e9SAndroid Build Coastguard Worker 7900*da0073e9SAndroid Build Coastguard Worker def may_break(y): 7901*da0073e9SAndroid Build Coastguard Worker # type: (int) 7902*da0073e9SAndroid Build Coastguard Worker x = 0 7903*da0073e9SAndroid Build Coastguard Worker for i in range(y): 7904*da0073e9SAndroid Build Coastguard Worker if y == 1: 7905*da0073e9SAndroid Build Coastguard Worker x = 5 7906*da0073e9SAndroid Build Coastguard Worker else: 7907*da0073e9SAndroid Build Coastguard Worker x = x + 1 7908*da0073e9SAndroid Build Coastguard Worker break 7909*da0073e9SAndroid Build Coastguard Worker x = -30 7910*da0073e9SAndroid Build Coastguard Worker return x 7911*da0073e9SAndroid Build Coastguard Worker 7912*da0073e9SAndroid Build Coastguard Worker self.checkScript(may_break, (1,)) 
7913*da0073e9SAndroid Build Coastguard Worker self.checkScript(may_break, (2,)) 7914*da0073e9SAndroid Build Coastguard Worker self.checkScript(may_break, (3,)) 7915*da0073e9SAndroid Build Coastguard Worker 7916*da0073e9SAndroid Build Coastguard Worker def test(x, y): 7917*da0073e9SAndroid Build Coastguard Worker # type: (int, int) 7918*da0073e9SAndroid Build Coastguard Worker a = 1 7919*da0073e9SAndroid Build Coastguard Worker while (x > 0): 7920*da0073e9SAndroid Build Coastguard Worker if y == 3: 7921*da0073e9SAndroid Build Coastguard Worker for i in range(y): 7922*da0073e9SAndroid Build Coastguard Worker a += (1 % (i + 1)) 7923*da0073e9SAndroid Build Coastguard Worker x -= 1 7924*da0073e9SAndroid Build Coastguard Worker if x == 3: 7925*da0073e9SAndroid Build Coastguard Worker a = x * 3 7926*da0073e9SAndroid Build Coastguard Worker break 7927*da0073e9SAndroid Build Coastguard Worker if x < 3: 7928*da0073e9SAndroid Build Coastguard Worker if x == 1: 7929*da0073e9SAndroid Build Coastguard Worker a -= 2 7930*da0073e9SAndroid Build Coastguard Worker x -= 1 7931*da0073e9SAndroid Build Coastguard Worker break 7932*da0073e9SAndroid Build Coastguard Worker a -= 1 7933*da0073e9SAndroid Build Coastguard Worker x -= 3 7934*da0073e9SAndroid Build Coastguard Worker return a, x 7935*da0073e9SAndroid Build Coastguard Worker 7936*da0073e9SAndroid Build Coastguard Worker self.checkScript(test, (10, 3)) 7937*da0073e9SAndroid Build Coastguard Worker self.checkScript(test, (10, 2)) 7938*da0073e9SAndroid Build Coastguard Worker self.checkScript(test, (3, 2)) 7939*da0073e9SAndroid Build Coastguard Worker self.checkScript(test, (5, 3)) 7940*da0073e9SAndroid Build Coastguard Worker self.checkScript(test, (2, 3)) 7941*da0073e9SAndroid Build Coastguard Worker 7942*da0073e9SAndroid Build Coastguard Worker def test_delete_after_break(x): 7943*da0073e9SAndroid Build Coastguard Worker # type: (int) 7944*da0073e9SAndroid Build Coastguard Worker a = 1 7945*da0073e9SAndroid Build Coastguard 
Worker b = 1 7946*da0073e9SAndroid Build Coastguard Worker for i in range(x): 7947*da0073e9SAndroid Build Coastguard Worker a = i * 3 7948*da0073e9SAndroid Build Coastguard Worker break 7949*da0073e9SAndroid Build Coastguard Worker b = i * 5 7950*da0073e9SAndroid Build Coastguard Worker return a, b 7951*da0073e9SAndroid Build Coastguard Worker 7952*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_delete_after_break, (0,)) 7953*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_delete_after_break, (1,)) 7954*da0073e9SAndroid Build Coastguard Worker 7955*da0073e9SAndroid Build Coastguard Worker def test_will_break_after_guard(x): 7956*da0073e9SAndroid Build Coastguard Worker # type: (int) 7957*da0073e9SAndroid Build Coastguard Worker a = 1 7958*da0073e9SAndroid Build Coastguard Worker for i in range(x): 7959*da0073e9SAndroid Build Coastguard Worker if i == 4: 7960*da0073e9SAndroid Build Coastguard Worker a = 3 7961*da0073e9SAndroid Build Coastguard Worker break 7962*da0073e9SAndroid Build Coastguard Worker a -= 1 7963*da0073e9SAndroid Build Coastguard Worker break 7964*da0073e9SAndroid Build Coastguard Worker assert 1 == 2 7965*da0073e9SAndroid Build Coastguard Worker a -= -100 7966*da0073e9SAndroid Build Coastguard Worker return a 7967*da0073e9SAndroid Build Coastguard Worker 7968*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_will_break_after_guard, (0,)) 7969*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_will_break_after_guard, (2,)) 7970*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_will_break_after_guard, (4,)) 7971*da0073e9SAndroid Build Coastguard Worker 7972*da0073e9SAndroid Build Coastguard Worker def test_varexit(cond): 7973*da0073e9SAndroid Build Coastguard Worker # type: (int) 7974*da0073e9SAndroid Build Coastguard Worker m = 0 7975*da0073e9SAndroid Build Coastguard Worker for i in range(3): 7976*da0073e9SAndroid Build Coastguard Worker if cond == 2: 7977*da0073e9SAndroid Build 
Coastguard Worker if cond == 2: 7978*da0073e9SAndroid Build Coastguard Worker m = 2 7979*da0073e9SAndroid Build Coastguard Worker break 7980*da0073e9SAndroid Build Coastguard Worker k = 1 7981*da0073e9SAndroid Build Coastguard Worker else: 7982*da0073e9SAndroid Build Coastguard Worker k = 2 7983*da0073e9SAndroid Build Coastguard Worker m += k 7984*da0073e9SAndroid Build Coastguard Worker return m 7985*da0073e9SAndroid Build Coastguard Worker 7986*da0073e9SAndroid Build Coastguard Worker # use of k tests the pathway where we have to insert unitialized 7987*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_varexit, (3,)) 7988*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_varexit, (2,)) 7989*da0073e9SAndroid Build Coastguard Worker 7990*da0073e9SAndroid Build Coastguard Worker def test_break_true(): 7991*da0073e9SAndroid Build Coastguard Worker i = 0 7992*da0073e9SAndroid Build Coastguard Worker while True: 7993*da0073e9SAndroid Build Coastguard Worker i += 1 7994*da0073e9SAndroid Build Coastguard Worker if i == 3: 7995*da0073e9SAndroid Build Coastguard Worker break 7996*da0073e9SAndroid Build Coastguard Worker while False: 7997*da0073e9SAndroid Build Coastguard Worker i += 1 7998*da0073e9SAndroid Build Coastguard Worker return i 7999*da0073e9SAndroid Build Coastguard Worker 8000*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_break_true, ()) 8001*da0073e9SAndroid Build Coastguard Worker 8002*da0073e9SAndroid Build Coastguard Worker def test_break_continue_error(self): 8003*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Syntax"): 8004*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 8005*da0073e9SAndroid Build Coastguard Worker def other_func(a): 8006*da0073e9SAndroid Build Coastguard Worker break 8007*da0073e9SAndroid Build Coastguard Worker ''') 8008*da0073e9SAndroid Build Coastguard Worker 8009*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, "Syntax"): 8010*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 8011*da0073e9SAndroid Build Coastguard Worker def other_func(a): 8012*da0073e9SAndroid Build Coastguard Worker for i in range(5): 8013*da0073e9SAndroid Build Coastguard Worker def foo(): 8014*da0073e9SAndroid Build Coastguard Worker break 8015*da0073e9SAndroid Build Coastguard Worker ''') 8016*da0073e9SAndroid Build Coastguard Worker 8017*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "do not support break or continue inside"): 8018*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8019*da0073e9SAndroid Build Coastguard Worker def foo(x): 8020*da0073e9SAndroid Build Coastguard Worker i = 0 8021*da0073e9SAndroid Build Coastguard Worker for a in (1, "2", 1.5): 8022*da0073e9SAndroid Build Coastguard Worker b = a 8023*da0073e9SAndroid Build Coastguard Worker if x: 8024*da0073e9SAndroid Build Coastguard Worker break 8025*da0073e9SAndroid Build Coastguard Worker return b 8026*da0073e9SAndroid Build Coastguard Worker 8027*da0073e9SAndroid Build Coastguard Worker def test_python_call(self): 8028*da0073e9SAndroid Build Coastguard Worker def pyfunc(a): 8029*da0073e9SAndroid Build Coastguard Worker return a * 3.0 8030*da0073e9SAndroid Build Coastguard Worker 8031*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 8032*da0073e9SAndroid Build Coastguard Worker def other_func(a): 8033*da0073e9SAndroid Build Coastguard Worker return a + a 8034*da0073e9SAndroid Build Coastguard Worker 8035*da0073e9SAndroid Build Coastguard Worker def test_call_python(a): 8036*da0073e9SAndroid Build Coastguard Worker b = pyfunc(a) 8037*da0073e9SAndroid Build Coastguard Worker b = other_func(b) 8038*da0073e9SAndroid Build Coastguard Worker i = 0 8039*da0073e9SAndroid Build Coastguard Worker step = 1 8040*da0073e9SAndroid Build Coastguard Worker while i < 10: 8041*da0073e9SAndroid Build Coastguard Worker 
b = pyfunc(b) 8042*da0073e9SAndroid Build Coastguard Worker if bool(b > 3.0): 8043*da0073e9SAndroid Build Coastguard Worker b = pyfunc(b) 8044*da0073e9SAndroid Build Coastguard Worker i = 11 8045*da0073e9SAndroid Build Coastguard Worker return b 8046*da0073e9SAndroid Build Coastguard Worker ''') 8047*da0073e9SAndroid Build Coastguard Worker inputs = self._make_scalar_vars([1], torch.float) 8048*da0073e9SAndroid Build Coastguard Worker outputs = self._make_scalar_vars([54], torch.float) 8049*da0073e9SAndroid Build Coastguard Worker 8050*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cu.test_call_python(*inputs), outputs[0]) 8051*da0073e9SAndroid Build Coastguard Worker 8052*da0073e9SAndroid Build Coastguard Worker def test_python_call_failure(self): 8053*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"): 8054*da0073e9SAndroid Build Coastguard Worker def pyfunc(a): 8055*da0073e9SAndroid Build Coastguard Worker return a * 3.0 8056*da0073e9SAndroid Build Coastguard Worker 8057*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 8058*da0073e9SAndroid Build Coastguard Worker def other_func(a): 8059*da0073e9SAndroid Build Coastguard Worker return a + a 8060*da0073e9SAndroid Build Coastguard Worker 8061*da0073e9SAndroid Build Coastguard Worker def test_call_python(a): 8062*da0073e9SAndroid Build Coastguard Worker b = pyfunc(a) 8063*da0073e9SAndroid Build Coastguard Worker b = other_func(b) 8064*da0073e9SAndroid Build Coastguard Worker i = 0 8065*da0073e9SAndroid Build Coastguard Worker step = 1 8066*da0073e9SAndroid Build Coastguard Worker while i < 10: 8067*da0073e9SAndroid Build Coastguard Worker b = pyfunc2(b) 8068*da0073e9SAndroid Build Coastguard Worker if b > 3.0: 8069*da0073e9SAndroid Build Coastguard Worker b = pyfunc(b) 8070*da0073e9SAndroid Build Coastguard Worker i = 11 8071*da0073e9SAndroid Build Coastguard Worker return b 8072*da0073e9SAndroid Build Coastguard Worker 
''') 8073*da0073e9SAndroid Build Coastguard Worker inputs = self._make_scalar_vars([1], torch.float) 8074*da0073e9SAndroid Build Coastguard Worker outputs = self._make_scalar_vars([54], torch.float) 8075*da0073e9SAndroid Build Coastguard Worker 8076*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cu.test_call_python(*inputs), outputs) 8077*da0073e9SAndroid Build Coastguard Worker 8078*da0073e9SAndroid Build Coastguard Worker def test_type_call_in_script(self): 8079*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8080*da0073e9SAndroid Build Coastguard Worker def fn(x): 8081*da0073e9SAndroid Build Coastguard Worker return type(x) 8082*da0073e9SAndroid Build Coastguard Worker 8083*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "value of type _TensorMeta"): 8084*da0073e9SAndroid Build Coastguard Worker fn(torch.tensor(.5)) 8085*da0073e9SAndroid Build Coastguard Worker 8086*da0073e9SAndroid Build Coastguard Worker def test_python_call_annotation(self): 8087*da0073e9SAndroid Build Coastguard Worker def pyfunc(a): 8088*da0073e9SAndroid Build Coastguard Worker return a * 3.0 8089*da0073e9SAndroid Build Coastguard Worker 8090*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8091*da0073e9SAndroid Build Coastguard Worker def foo(a): 8092*da0073e9SAndroid Build Coastguard Worker return pyfunc(a) + pyfunc(a) 8093*da0073e9SAndroid Build Coastguard Worker 8094*da0073e9SAndroid Build Coastguard Worker inputs = self._make_scalar_vars([1], torch.float) 8095*da0073e9SAndroid Build Coastguard Worker outputs = self._make_scalar_vars([6], torch.float) 8096*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(*inputs), outputs[0]) 8097*da0073e9SAndroid Build Coastguard Worker 8098*da0073e9SAndroid Build Coastguard Worker def test_python_call_annoytation_failure(self): 8099*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"): 8100*da0073e9SAndroid Build 
Coastguard Worker def pyfunc(a): 8101*da0073e9SAndroid Build Coastguard Worker return a * 3.0 8102*da0073e9SAndroid Build Coastguard Worker 8103*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8104*da0073e9SAndroid Build Coastguard Worker def foo(a): 8105*da0073e9SAndroid Build Coastguard Worker return pyfunc2(a) + pyfunc(a) # noqa: F821 8106*da0073e9SAndroid Build Coastguard Worker 8107*da0073e9SAndroid Build Coastguard Worker inputs = self._make_scalar_vars([1], torch.float) 8108*da0073e9SAndroid Build Coastguard Worker outputs = self._make_scalar_vars([6], torch.float) 8109*da0073e9SAndroid Build Coastguard Worker 8110*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(*inputs), outputs[0]) 8111*da0073e9SAndroid Build Coastguard Worker 8112*da0073e9SAndroid Build Coastguard Worker def test_desugar_module(self): 8113*da0073e9SAndroid Build Coastguard Worker import torch.nn.functional as F 8114*da0073e9SAndroid Build Coastguard Worker 8115*da0073e9SAndroid Build Coastguard Worker def fn(x, slope): 8116*da0073e9SAndroid Build Coastguard Worker a = torch.abs(x) 8117*da0073e9SAndroid Build Coastguard Worker b = torch.nn.functional.prelu(x, slope) 8118*da0073e9SAndroid Build Coastguard Worker c = F.prelu(x, slope) 8119*da0073e9SAndroid Build Coastguard Worker return a, b, c 8120*da0073e9SAndroid Build Coastguard Worker 8121*da0073e9SAndroid Build Coastguard Worker x = torch.arange(-3., 4) 8122*da0073e9SAndroid Build Coastguard Worker slope = torch.tensor([0.5]) 8123*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, [x, slope], optimize=True) 8124*da0073e9SAndroid Build Coastguard Worker 8125*da0073e9SAndroid Build Coastguard Worker def test_script_docstring(self): 8126*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8127*da0073e9SAndroid Build Coastguard Worker def with_docstring(x): 8128*da0073e9SAndroid Build Coastguard Worker """test str""" 8129*da0073e9SAndroid Build Coastguard Worker y = x 8130*da0073e9SAndroid Build 
Coastguard Worker """y is the same as x""" 8131*da0073e9SAndroid Build Coastguard Worker return y 8132*da0073e9SAndroid Build Coastguard Worker self.assertEqual(with_docstring.__doc__, 'test str') 8133*da0073e9SAndroid Build Coastguard Worker 8134*da0073e9SAndroid Build Coastguard Worker def test_script_method_docstring(self): 8135*da0073e9SAndroid Build Coastguard Worker class A(torch.jit.ScriptModule): 8136*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8137*da0073e9SAndroid Build Coastguard Worker def with_docstring(self, x): 8138*da0073e9SAndroid Build Coastguard Worker """test str""" 8139*da0073e9SAndroid Build Coastguard Worker y = x 8140*da0073e9SAndroid Build Coastguard Worker """y is the same as x""" 8141*da0073e9SAndroid Build Coastguard Worker return y 8142*da0073e9SAndroid Build Coastguard Worker a = A() 8143*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.with_docstring.__doc__, 'test str') 8144*da0073e9SAndroid Build Coastguard Worker 8145*da0073e9SAndroid Build Coastguard Worker def test_script_module(self): 8146*da0073e9SAndroid Build Coastguard Worker class M1(torch.jit.ScriptModule): 8147*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8148*da0073e9SAndroid Build Coastguard Worker super().__init__() 8149*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(2)) 8150*da0073e9SAndroid Build Coastguard Worker 8151*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8152*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 8153*da0073e9SAndroid Build Coastguard Worker return self.weight + thing 8154*da0073e9SAndroid Build Coastguard Worker 8155*da0073e9SAndroid Build Coastguard Worker class PModule(nn.Module): 8156*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8157*da0073e9SAndroid Build Coastguard Worker super().__init__() 8158*da0073e9SAndroid Build Coastguard Worker self.a = nn.Parameter(torch.randn(2, 3)) 
8159*da0073e9SAndroid Build Coastguard Worker 8160*da0073e9SAndroid Build Coastguard Worker def forward(self, a): 8161*da0073e9SAndroid Build Coastguard Worker return self.a.mm(a) 8162*da0073e9SAndroid Build Coastguard Worker 8163*da0073e9SAndroid Build Coastguard Worker class M2(torch.jit.ScriptModule): 8164*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8165*da0073e9SAndroid Build Coastguard Worker super().__init__() 8166*da0073e9SAndroid Build Coastguard Worker # test submodule 8167*da0073e9SAndroid Build Coastguard Worker self.sub = M1() 8168*da0073e9SAndroid Build Coastguard Worker self.sub2 = PModule() 8169*da0073e9SAndroid Build Coastguard Worker # test parameters 8170*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(2, 3)) 8171*da0073e9SAndroid Build Coastguard Worker self.bias = nn.Parameter(torch.randn(2)) 8172*da0073e9SAndroid Build Coastguard Worker # test defining a method from a string 8173*da0073e9SAndroid Build Coastguard Worker self.define(""" 8174*da0073e9SAndroid Build Coastguard Worker def hi(self, a): 8175*da0073e9SAndroid Build Coastguard Worker return self.weight.mm(a) 8176*da0073e9SAndroid Build Coastguard Worker """) 8177*da0073e9SAndroid Build Coastguard Worker # test script methods 8178*da0073e9SAndroid Build Coastguard Worker 8179*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8180*da0073e9SAndroid Build Coastguard Worker def doit(self, input): 8181*da0073e9SAndroid Build Coastguard Worker # test use of parameter 8182*da0073e9SAndroid Build Coastguard Worker return self.weight.mm(input) 8183*da0073e9SAndroid Build Coastguard Worker 8184*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8185*da0073e9SAndroid Build Coastguard Worker def doit2(self, input): 8186*da0073e9SAndroid Build Coastguard Worker return self.weight.mm(input) 8187*da0073e9SAndroid Build Coastguard Worker 8188*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 
8189*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 8190*da0073e9SAndroid Build Coastguard Worker a = self.doit(input) 8191*da0073e9SAndroid Build Coastguard Worker b = self.doit2(input) 8192*da0073e9SAndroid Build Coastguard Worker c = self.hi(input) 8193*da0073e9SAndroid Build Coastguard Worker d = self.sub2(input) 8194*da0073e9SAndroid Build Coastguard Worker return a + b + self.bias + self.sub(a) + c + d 8195*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 8196*da0073e9SAndroid Build Coastguard Worker m2 = M2() 8197*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 2) 8198*da0073e9SAndroid Build Coastguard Worker a = m2.weight.mm(input) 8199*da0073e9SAndroid Build Coastguard Worker b = m2.weight.mm(input) 8200*da0073e9SAndroid Build Coastguard Worker c = m2.weight.mm(input) 8201*da0073e9SAndroid Build Coastguard Worker d = m2.sub2.a.mm(input) 8202*da0073e9SAndroid Build Coastguard Worker ref = a + b + m2.bias + m2.sub.weight + a + c + d 8203*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ref, m2.forward(input)) 8204*da0073e9SAndroid Build Coastguard Worker m2.weight = nn.Parameter(torch.zeros_like(m2.weight)) 8205*da0073e9SAndroid Build Coastguard Worker m2.bias = nn.Parameter(torch.zeros_like(m2.bias)) 8206*da0073e9SAndroid Build Coastguard Worker m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight)) 8207*da0073e9SAndroid Build Coastguard Worker m2.sub2.a.data.zero_() 8208*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2))) 8209*da0073e9SAndroid Build Coastguard Worker 8210*da0073e9SAndroid Build Coastguard Worker def test_irparser(self): 8211*da0073e9SAndroid Build Coastguard Worker graph_str = """graph(%0 : Double(5, 5)): 8212*da0073e9SAndroid Build Coastguard Worker # CHECK: aten::relu 8213*da0073e9SAndroid Build Coastguard Worker %1 : Double(5, 5) = aten::relu(%0) 8214*da0073e9SAndroid Build Coastguard Worker 
return (%1) 8215*da0073e9SAndroid Build Coastguard Worker """ 8216*da0073e9SAndroid Build Coastguard Worker FileCheck().run(graph_str, parse_ir(graph_str)) 8217*da0073e9SAndroid Build Coastguard Worker 8218*da0073e9SAndroid Build Coastguard Worker def test_parse_tensor_constants(self): 8219*da0073e9SAndroid Build Coastguard Worker def foo(): 8220*da0073e9SAndroid Build Coastguard Worker return torch.zeros([4, 4]) 8221*da0073e9SAndroid Build Coastguard Worker 8222*da0073e9SAndroid Build Coastguard Worker foo_s = torch.jit.script(foo) 8223*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_constant_propagation(foo_s.graph) 8224*da0073e9SAndroid Build Coastguard Worker 8225*da0073e9SAndroid Build Coastguard Worker g = str(foo_s.graph) 8226*da0073e9SAndroid Build Coastguard Worker g_parsed = parse_ir(g, parse_tensor_constants=True) 8227*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(canonical(g_parsed)), str(canonical(foo_s.graph))) 8228*da0073e9SAndroid Build Coastguard Worker func = torch._C._create_function_from_graph("forward", g_parsed) 8229*da0073e9SAndroid Build Coastguard Worker 8230*da0073e9SAndroid Build Coastguard Worker out_parsed = func() 8231*da0073e9SAndroid Build Coastguard Worker out_func = foo() 8232*da0073e9SAndroid Build Coastguard Worker # not checking data, just dtype, size etc 8233*da0073e9SAndroid Build Coastguard Worker out_parsed[:] = 0 8234*da0073e9SAndroid Build Coastguard Worker out_func[:] = 0 8235*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_func, out_parsed) 8236*da0073e9SAndroid Build Coastguard Worker 8237*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 8238*da0073e9SAndroid Build Coastguard Worker parse_ir(g, parse_tensor_constants=False) 8239*da0073e9SAndroid Build Coastguard Worker 8240*da0073e9SAndroid Build Coastguard Worker def test_parse_nested_names(self): 8241*da0073e9SAndroid Build Coastguard Worker g_str = """ 8242*da0073e9SAndroid Build Coastguard 
Worker graph(%x.1 : Tensor): 8243*da0073e9SAndroid Build Coastguard Worker %3 : int = prim::Constant[value=1]() 8244*da0073e9SAndroid Build Coastguard Worker %2 : int = prim::Constant[value=2]() 8245*da0073e9SAndroid Build Coastguard Worker %hi.submod.value.5 : Tensor = aten::add(%x.1, %2, %3) 8246*da0073e9SAndroid Build Coastguard Worker return (%hi.submod.value.5) 8247*da0073e9SAndroid Build Coastguard Worker """ 8248*da0073e9SAndroid Build Coastguard Worker g = parse_ir(g_str) 8249*da0073e9SAndroid Build Coastguard Worker round_trip_g = parse_ir(str(g)) 8250*da0073e9SAndroid Build Coastguard Worker self.assertEqual(canonical(g), canonical(round_trip_g)) 8251*da0073e9SAndroid Build Coastguard Worker 8252*da0073e9SAndroid Build Coastguard Worker func1 = torch._C._create_function_from_graph("forward", g) 8253*da0073e9SAndroid Build Coastguard Worker func2 = torch._C._create_function_from_graph("forward", round_trip_g) 8254*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func1(torch.ones([2])), func2(torch.ones([2]))) 8255*da0073e9SAndroid Build Coastguard Worker 8256*da0073e9SAndroid Build Coastguard Worker def test_is_after_use(self): 8257*da0073e9SAndroid Build Coastguard Worker def sorted_input_use(g): 8258*da0073e9SAndroid Build Coastguard Worker uses = list(next(g.inputs()).uses()) 8259*da0073e9SAndroid Build Coastguard Worker return sorted(uses, key=functools.cmp_to_key(type(uses[0]).isAfter)) 8260*da0073e9SAndroid Build Coastguard Worker 8261*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8262*da0073e9SAndroid Build Coastguard Worker def foo(x): 8263*da0073e9SAndroid Build Coastguard Worker a = x + 1 8264*da0073e9SAndroid Build Coastguard Worker return (x, x, a) 8265*da0073e9SAndroid Build Coastguard Worker 8266*da0073e9SAndroid Build Coastguard Worker uses_sorted = sorted_input_use(foo.graph) 8267*da0073e9SAndroid Build Coastguard Worker # sorts last use to the end 8268*da0073e9SAndroid Build Coastguard Worker 
self.assertFalse(uses_sorted[0].isAfter(uses_sorted[1])) 8269*da0073e9SAndroid Build Coastguard Worker self.assertTrue(uses_sorted[0].user.kind() == "aten::add") 8270*da0073e9SAndroid Build Coastguard Worker self.assertEqual(uses_sorted[1].offset, 0) 8271*da0073e9SAndroid Build Coastguard Worker 8272*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8273*da0073e9SAndroid Build Coastguard Worker def foo(x, cond: bool): 8274*da0073e9SAndroid Build Coastguard Worker if cond: 8275*da0073e9SAndroid Build Coastguard Worker return x + 3 8276*da0073e9SAndroid Build Coastguard Worker else: 8277*da0073e9SAndroid Build Coastguard Worker return x - 3 8278*da0073e9SAndroid Build Coastguard Worker 8279*da0073e9SAndroid Build Coastguard Worker uses_sorted = sorted_input_use(foo.graph) 8280*da0073e9SAndroid Build Coastguard Worker self.assertTrue(uses_sorted[0].user.kind() == "aten::add") 8281*da0073e9SAndroid Build Coastguard Worker self.assertTrue(uses_sorted[1].user.kind() == "aten::sub") 8282*da0073e9SAndroid Build Coastguard Worker 8283*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8284*da0073e9SAndroid Build Coastguard Worker def foo(x, cond: bool, cond2: bool): 8285*da0073e9SAndroid Build Coastguard Worker if cond: 8286*da0073e9SAndroid Build Coastguard Worker return x + 3 8287*da0073e9SAndroid Build Coastguard Worker elif cond2 : 8288*da0073e9SAndroid Build Coastguard Worker return x - 3 8289*da0073e9SAndroid Build Coastguard Worker 8290*da0073e9SAndroid Build Coastguard Worker return x / 3 8291*da0073e9SAndroid Build Coastguard Worker 8292*da0073e9SAndroid Build Coastguard Worker graph1 = foo.graph 8293*da0073e9SAndroid Build Coastguard Worker 8294*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8295*da0073e9SAndroid Build Coastguard Worker def foo(x, cond: bool, cond2: bool): 8296*da0073e9SAndroid Build Coastguard Worker if cond: 8297*da0073e9SAndroid Build Coastguard Worker return x + 3 8298*da0073e9SAndroid Build Coastguard Worker else: 
8299*da0073e9SAndroid Build Coastguard Worker if cond2 : 8300*da0073e9SAndroid Build Coastguard Worker return x - 3 8301*da0073e9SAndroid Build Coastguard Worker return x / 3 8302*da0073e9SAndroid Build Coastguard Worker 8303*da0073e9SAndroid Build Coastguard Worker graph2 = foo.graph 8304*da0073e9SAndroid Build Coastguard Worker 8305*da0073e9SAndroid Build Coastguard Worker for graph in [graph1, graph2]: 8306*da0073e9SAndroid Build Coastguard Worker uses_sorted = sorted_input_use(graph) 8307*da0073e9SAndroid Build Coastguard Worker self.assertTrue(uses_sorted[0].user.kind() == "aten::add") 8308*da0073e9SAndroid Build Coastguard Worker self.assertTrue(uses_sorted[1].user.kind() == "aten::sub") 8309*da0073e9SAndroid Build Coastguard Worker self.assertTrue(uses_sorted[2].user.kind() == "aten::div") 8310*da0073e9SAndroid Build Coastguard Worker 8311*da0073e9SAndroid Build Coastguard Worker def test_canonicalize_control_outputs(self): 8312*da0073e9SAndroid Build Coastguard Worker def test_all_outputs(g): 8313*da0073e9SAndroid Build Coastguard Worker ifs = g.findAllNodes("prim::If") 8314*da0073e9SAndroid Build Coastguard Worker loops = g.findAllNodes("prim::Loop") 8315*da0073e9SAndroid Build Coastguard Worker 8316*da0073e9SAndroid Build Coastguard Worker def contained_blocks(node): 8317*da0073e9SAndroid Build Coastguard Worker return len(node.findAllNodes("prim::If")) * 2 + len(node.findAllNodes("prim::Loop")) 8318*da0073e9SAndroid Build Coastguard Worker for node in ifs + loops: 8319*da0073e9SAndroid Build Coastguard Worker outs = list(node.outputs()) 8320*da0073e9SAndroid Build Coastguard Worker out_name = [x.debugName() for x in outs] 8321*da0073e9SAndroid Build Coastguard Worker if len(out_name) == 0: 8322*da0073e9SAndroid Build Coastguard Worker continue 8323*da0073e9SAndroid Build Coastguard Worker fc = FileCheck() 8324*da0073e9SAndroid Build Coastguard Worker # find the last output, then all subsequent uses 8325*da0073e9SAndroid Build Coastguard Worker 
fc.check(out_name[-1] + " : ") 8326*da0073e9SAndroid Build Coastguard Worker # skip past node body 8327*da0073e9SAndroid Build Coastguard Worker for i in range(contained_blocks(node)): 8328*da0073e9SAndroid Build Coastguard Worker fc.check("->") 8329*da0073e9SAndroid Build Coastguard Worker if (node.kind() == "prim::If"): 8330*da0073e9SAndroid Build Coastguard Worker fc.check("->").check("->").check("\n") 8331*da0073e9SAndroid Build Coastguard Worker else: 8332*da0073e9SAndroid Build Coastguard Worker fc.check("->").check("\n") 8333*da0073e9SAndroid Build Coastguard Worker # the canonical order is the same order as the first use 8334*da0073e9SAndroid Build Coastguard Worker # appears in text 8335*da0073e9SAndroid Build Coastguard Worker for name in out_name: 8336*da0073e9SAndroid Build Coastguard Worker fc.check(name) 8337*da0073e9SAndroid Build Coastguard Worker fc.run(g) 8338*da0073e9SAndroid Build Coastguard Worker 8339*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8340*da0073e9SAndroid Build Coastguard Worker def test(x): 8341*da0073e9SAndroid Build Coastguard Worker # type: (bool) -> Tuple[int, int] 8342*da0073e9SAndroid Build Coastguard Worker b = 2 8343*da0073e9SAndroid Build Coastguard Worker a = 1 8344*da0073e9SAndroid Build Coastguard Worker if x: 8345*da0073e9SAndroid Build Coastguard Worker a = 1 8346*da0073e9SAndroid Build Coastguard Worker b = 2 8347*da0073e9SAndroid Build Coastguard Worker x = False 8348*da0073e9SAndroid Build Coastguard Worker if x: 8349*da0073e9SAndroid Build Coastguard Worker b = a 8350*da0073e9SAndroid Build Coastguard Worker else: 8351*da0073e9SAndroid Build Coastguard Worker a = b 8352*da0073e9SAndroid Build Coastguard Worker 8353*da0073e9SAndroid Build Coastguard Worker return a, b 8354*da0073e9SAndroid Build Coastguard Worker test_all_outputs(test.graph) 8355*da0073e9SAndroid Build Coastguard Worker 8356*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8357*da0073e9SAndroid Build Coastguard Worker def 
test2(x): 8358*da0073e9SAndroid Build Coastguard Worker # type: (bool) -> Tuple[int, int] 8359*da0073e9SAndroid Build Coastguard Worker b = 2 8360*da0073e9SAndroid Build Coastguard Worker a = 1 8361*da0073e9SAndroid Build Coastguard Worker if x: 8362*da0073e9SAndroid Build Coastguard Worker a = 1 8363*da0073e9SAndroid Build Coastguard Worker b = 2 8364*da0073e9SAndroid Build Coastguard Worker x = False 8365*da0073e9SAndroid Build Coastguard Worker if x: 8366*da0073e9SAndroid Build Coastguard Worker print(a) 8367*da0073e9SAndroid Build Coastguard Worker else: 8368*da0073e9SAndroid Build Coastguard Worker if x: 8369*da0073e9SAndroid Build Coastguard Worker print(b) 8370*da0073e9SAndroid Build Coastguard Worker 8371*da0073e9SAndroid Build Coastguard Worker return a, b 8372*da0073e9SAndroid Build Coastguard Worker test_all_outputs(test2.graph) 8373*da0073e9SAndroid Build Coastguard Worker 8374*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8375*da0073e9SAndroid Build Coastguard Worker def test_loop(x, iter): 8376*da0073e9SAndroid Build Coastguard Worker # type: (bool, int) -> (None) 8377*da0073e9SAndroid Build Coastguard Worker a = 1 8378*da0073e9SAndroid Build Coastguard Worker b = 2 8379*da0073e9SAndroid Build Coastguard Worker c = 3 8380*da0073e9SAndroid Build Coastguard Worker for i in range(iter): 8381*da0073e9SAndroid Build Coastguard Worker a = 4 8382*da0073e9SAndroid Build Coastguard Worker b = 5 8383*da0073e9SAndroid Build Coastguard Worker c = 6 8384*da0073e9SAndroid Build Coastguard Worker x = True 8385*da0073e9SAndroid Build Coastguard Worker print(c) 8386*da0073e9SAndroid Build Coastguard Worker if x: 8387*da0073e9SAndroid Build Coastguard Worker print(a, b) 8388*da0073e9SAndroid Build Coastguard Worker test_all_outputs(test_loop.graph) 8389*da0073e9SAndroid Build Coastguard Worker 8390*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8391*da0073e9SAndroid Build Coastguard Worker def loop_unused(iter): 8392*da0073e9SAndroid Build 
Coastguard Worker # type: (int) -> (None) 8393*da0073e9SAndroid Build Coastguard Worker a = 1 8394*da0073e9SAndroid Build Coastguard Worker b = 2 8395*da0073e9SAndroid Build Coastguard Worker c = 3 8396*da0073e9SAndroid Build Coastguard Worker for i in range(iter): 8397*da0073e9SAndroid Build Coastguard Worker c = c + 1 8398*da0073e9SAndroid Build Coastguard Worker b = b + 1 8399*da0073e9SAndroid Build Coastguard Worker a = a + 1 8400*da0073e9SAndroid Build Coastguard Worker print(a, b) 8401*da0073e9SAndroid Build Coastguard Worker print(c) 8402*da0073e9SAndroid Build Coastguard Worker 8403*da0073e9SAndroid Build Coastguard Worker # c is used, then unused should be ordered by alphabetical 8404*da0073e9SAndroid Build Coastguard Worker FileCheck().check(r"%c : int, %a : int, %b : int").run(loop_unused.graph) 8405*da0073e9SAndroid Build Coastguard Worker 8406*da0073e9SAndroid Build Coastguard Worker def test_filecheck(self): 8407*da0073e9SAndroid Build Coastguard Worker def test_check(): 8408*da0073e9SAndroid Build Coastguard Worker file = "232" 8409*da0073e9SAndroid Build Coastguard Worker FileCheck().check("2").check("3").check("2").run(file) 8410*da0073e9SAndroid Build Coastguard Worker FileCheck().check("232").run(file) 8411*da0073e9SAndroid Build Coastguard Worker 8412*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): 8413*da0073e9SAndroid Build Coastguard Worker FileCheck().check("22").run(file) 8414*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "CHECK: 3"): 8415*da0073e9SAndroid Build Coastguard Worker FileCheck().check("3").check("3").run(file) 8416*da0073e9SAndroid Build Coastguard Worker 8417*da0073e9SAndroid Build Coastguard Worker test_check() 8418*da0073e9SAndroid Build Coastguard Worker 8419*da0073e9SAndroid Build Coastguard Worker def test_check_count(): 8420*da0073e9SAndroid Build Coastguard Worker file = "22222" 8421*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check_count("2", 5).run(file) 8422*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("22", 2).run(file) 8423*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("222", 1).run(file) 8424*da0073e9SAndroid Build Coastguard Worker 8425*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): 8426*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("2", 4, exactly=True).run(file) 8427*da0073e9SAndroid Build Coastguard Worker 8428*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): 8429*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("22", 3).run(file) 8430*da0073e9SAndroid Build Coastguard Worker 8431*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"): 8432*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("2", 6).run(file) 8433*da0073e9SAndroid Build Coastguard Worker 8434*da0073e9SAndroid Build Coastguard Worker test_check_count() 8435*da0073e9SAndroid Build Coastguard Worker 8436*da0073e9SAndroid Build Coastguard Worker def test_check_same(): 8437*da0073e9SAndroid Build Coastguard Worker file = "22\n33" 8438*da0073e9SAndroid Build Coastguard Worker FileCheck().check_same("22").run(file) 8439*da0073e9SAndroid Build Coastguard Worker 8440*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected to not find"): 8441*da0073e9SAndroid Build Coastguard Worker FileCheck().check_same("33").run(file) 8442*da0073e9SAndroid Build Coastguard Worker 8443*da0073e9SAndroid Build Coastguard Worker file = "22 1 3" 8444*da0073e9SAndroid Build Coastguard Worker 8445*da0073e9SAndroid Build Coastguard Worker FileCheck().check("2").check_same("3").run(file) 8446*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("2", 2).check_same("3").run(file) 8447*da0073e9SAndroid Build Coastguard Worker 
8448*da0073e9SAndroid Build Coastguard Worker test_check_same() 8449*da0073e9SAndroid Build Coastguard Worker 8450*da0073e9SAndroid Build Coastguard Worker def test_check_next(): 8451*da0073e9SAndroid Build Coastguard Worker file = "\n1\n2\n3" 8452*da0073e9SAndroid Build Coastguard Worker FileCheck().check("1").check_next("2").check_next("3").run(file) 8453*da0073e9SAndroid Build Coastguard Worker FileCheck().check_next("1").check_next("2").check_next("3").run(file) 8454*da0073e9SAndroid Build Coastguard Worker 8455*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected to find"): 8456*da0073e9SAndroid Build Coastguard Worker FileCheck().check("1").check_next("2").run("12") 8457*da0073e9SAndroid Build Coastguard Worker 8458*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected to not find"): 8459*da0073e9SAndroid Build Coastguard Worker FileCheck().check("1").check_next("2").run("1\n\n2") 8460*da0073e9SAndroid Build Coastguard Worker 8461*da0073e9SAndroid Build Coastguard Worker test_check_next() 8462*da0073e9SAndroid Build Coastguard Worker 8463*da0073e9SAndroid Build Coastguard Worker def test_check_dag(): 8464*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check_dag("1").check_dag("2").check_not("2") 8465*da0073e9SAndroid Build Coastguard Worker fc.run("12") 8466*da0073e9SAndroid Build Coastguard Worker fc.run("21") 8467*da0073e9SAndroid Build Coastguard Worker 8468*da0073e9SAndroid Build Coastguard Worker fc = FileCheck() 8469*da0073e9SAndroid Build Coastguard Worker fc.check_not("3").check_dag("1").check_dag("2").check_not("3") 8470*da0073e9SAndroid Build Coastguard Worker fc.run("1 3 2") 8471*da0073e9SAndroid Build Coastguard Worker fc.run("2 3 1") 8472*da0073e9SAndroid Build Coastguard Worker 8473*da0073e9SAndroid Build Coastguard Worker fc = FileCheck().check_dag("1").check_dag("2").check("3") 8474*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'): 8475*da0073e9SAndroid Build Coastguard Worker fc.run("1 3 2") 8476*da0073e9SAndroid Build Coastguard Worker 8477*da0073e9SAndroid Build Coastguard Worker test_check_dag() 8478*da0073e9SAndroid Build Coastguard Worker 8479*da0073e9SAndroid Build Coastguard Worker def test_check_not(): 8480*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("2").check("1").run("12") 8481*da0073e9SAndroid Build Coastguard Worker FileCheck().check("2").check_not("2").run("12") 8482*da0073e9SAndroid Build Coastguard Worker 8483*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'): 8484*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("2").check("1").run("21") 8485*da0073e9SAndroid Build Coastguard Worker 8486*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): 8487*da0073e9SAndroid Build Coastguard Worker FileCheck().check("2").check_not("1").run("21") 8488*da0073e9SAndroid Build Coastguard Worker 8489*da0073e9SAndroid Build Coastguard Worker # checks with distinct range matchings 8490*da0073e9SAndroid Build Coastguard Worker fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2") 8491*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'): 8492*da0073e9SAndroid Build Coastguard Worker fb.run("22 2 22") 8493*da0073e9SAndroid Build Coastguard Worker 8494*da0073e9SAndroid Build Coastguard Worker fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2) 8495*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'): 8496*da0073e9SAndroid Build Coastguard Worker fb.run("22 1 22") 8497*da0073e9SAndroid Build Coastguard Worker 8498*da0073e9SAndroid Build Coastguard Worker def _dtype_to_jit_name(self, dtype): 8499*da0073e9SAndroid 
Build Coastguard Worker if dtype == torch.float32: 8500*da0073e9SAndroid Build Coastguard Worker return "Float" 8501*da0073e9SAndroid Build Coastguard Worker if dtype == torch.float64: 8502*da0073e9SAndroid Build Coastguard Worker return "Double" 8503*da0073e9SAndroid Build Coastguard Worker if dtype == torch.int64: 8504*da0073e9SAndroid Build Coastguard Worker return "Long" 8505*da0073e9SAndroid Build Coastguard Worker if dtype == torch.int32: 8506*da0073e9SAndroid Build Coastguard Worker return "Int" 8507*da0073e9SAndroid Build Coastguard Worker if dtype == torch.bool: 8508*da0073e9SAndroid Build Coastguard Worker return "Bool" 8509*da0073e9SAndroid Build Coastguard Worker raise RuntimeError('dtype not handled') 8510*da0073e9SAndroid Build Coastguard Worker 8511*da0073e9SAndroid Build Coastguard Worker def _dtype_to_expect(self, dtype, dim=0): 8512*da0073e9SAndroid Build Coastguard Worker param = ', '.join(['*'] * dim + ['device=cpu']) 8513*da0073e9SAndroid Build Coastguard Worker param = '(' + param + ')' 8514*da0073e9SAndroid Build Coastguard Worker jit_type = self._dtype_to_jit_name(dtype) 8515*da0073e9SAndroid Build Coastguard Worker if dim >= 0: 8516*da0073e9SAndroid Build Coastguard Worker return jit_type + param 8517*da0073e9SAndroid Build Coastguard Worker # special case representing wrapped number 8518*da0073e9SAndroid Build Coastguard Worker else: 8519*da0073e9SAndroid Build Coastguard Worker return jit_type.lower() 8520*da0073e9SAndroid Build Coastguard Worker 8521*da0073e9SAndroid Build Coastguard Worker 8522*da0073e9SAndroid Build Coastguard Worker def _test_dtype_op_shape(self, ops, args, input_dims=1): 8523*da0073e9SAndroid Build Coastguard Worker if input_dims < 1: 8524*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("input dims must be at least 1") 8525*da0073e9SAndroid Build Coastguard Worker dtypes = [torch.float32, torch.float64, torch.int64, torch.int32] 8526*da0073e9SAndroid Build Coastguard Worker str_args = ', '.join([str(arg) 
for arg in args]) + (', ' if len(args) else '') 8527*da0073e9SAndroid Build Coastguard Worker tensor_data = ('[' * input_dims) + '1, 2, 3' + (input_dims * ']') 8528*da0073e9SAndroid Build Coastguard Worker template = dedent(''' 8529*da0073e9SAndroid Build Coastguard Worker def func(): 8530*da0073e9SAndroid Build Coastguard Worker return {return_line} 8531*da0073e9SAndroid Build Coastguard Worker ''') 8532*da0073e9SAndroid Build Coastguard Worker 8533*da0073e9SAndroid Build Coastguard Worker for op in ops: 8534*da0073e9SAndroid Build Coastguard Worker for dtype in (dtypes + [None]): 8535*da0073e9SAndroid Build Coastguard Worker for tensor_type in dtypes: 8536*da0073e9SAndroid Build Coastguard Worker # a couple of ops aren't implemented for non-floating types 8537*da0073e9SAndroid Build Coastguard Worker if not tensor_type.is_floating_point or (dtype is not None and not dtype.is_floating_point): 8538*da0073e9SAndroid Build Coastguard Worker if op in ['mean', 'softmax', 'log_softmax']: 8539*da0073e9SAndroid Build Coastguard Worker continue 8540*da0073e9SAndroid Build Coastguard Worker return_line = f"torch.tensor({tensor_data}, dtype={tensor_type}).{op}({str_args}dtype={dtype})" 8541*da0073e9SAndroid Build Coastguard Worker # uncomment for debugging a failed test: 8542*da0073e9SAndroid Build Coastguard Worker # print("testing {}".format(return_line)) 8543*da0073e9SAndroid Build Coastguard Worker code = template.format(return_line=return_line) 8544*da0073e9SAndroid Build Coastguard Worker scope = {} 8545*da0073e9SAndroid Build Coastguard Worker exec(code, globals(), scope) 8546*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 8547*da0073e9SAndroid Build Coastguard Worker graph = cu.func.graph 8548*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_complete_shape_analysis(graph, (), False) 8549*da0073e9SAndroid Build Coastguard Worker input_array = [1, 2, 3] 8550*da0073e9SAndroid Build Coastguard Worker for _ in range(1, 
input_dims): 8551*da0073e9SAndroid Build Coastguard Worker input_array = [input_array] 8552*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(input_array, dtype=tensor_type) 8553*da0073e9SAndroid Build Coastguard Worker attr = getattr(t, op) 8554*da0073e9SAndroid Build Coastguard Worker kwargs = {'dtype': dtype} 8555*da0073e9SAndroid Build Coastguard Worker result = attr(*args, **kwargs) 8556*da0073e9SAndroid Build Coastguard Worker expect = self._dtype_to_expect(result.dtype, result.dim()) 8557*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::tensor").check(expect).run(graph) 8558*da0073e9SAndroid Build Coastguard Worker 8559*da0073e9SAndroid Build Coastguard Worker def test_dtype_op_shape(self): 8560*da0073e9SAndroid Build Coastguard Worker ops = ['prod'] 8561*da0073e9SAndroid Build Coastguard Worker self._test_dtype_op_shape(ops, args=[]) 8562*da0073e9SAndroid Build Coastguard Worker self._test_dtype_op_shape(ops, args=[0, False]) 8563*da0073e9SAndroid Build Coastguard Worker self._test_dtype_op_shape(ops, args=[0, False]) 8564*da0073e9SAndroid Build Coastguard Worker self._test_dtype_op_shape(ops, args=[0, True]) 8565*da0073e9SAndroid Build Coastguard Worker 8566*da0073e9SAndroid Build Coastguard Worker def test_dtype_op_shape2(self): 8567*da0073e9SAndroid Build Coastguard Worker ops = ['cumprod', 'cumsum', 'softmax', 'log_softmax'] 8568*da0073e9SAndroid Build Coastguard Worker self._test_dtype_op_shape(ops, args=[0]) 8569*da0073e9SAndroid Build Coastguard Worker 8570*da0073e9SAndroid Build Coastguard Worker self._test_dtype_op_shape(ops, args=[1], input_dims=4) 8571*da0073e9SAndroid Build Coastguard Worker 8572*da0073e9SAndroid Build Coastguard Worker 8573*da0073e9SAndroid Build Coastguard Worker def _test_binary_op_shape(self, ops, input_dims=1): 8574*da0073e9SAndroid Build Coastguard Worker 8575*da0073e9SAndroid Build Coastguard Worker dtypes = [torch.float32, torch.float64, torch.int64, torch.int32, torch.bool] 8576*da0073e9SAndroid 
Build Coastguard Worker 8577*da0073e9SAndroid Build Coastguard Worker if input_dims == 0: 8578*da0073e9SAndroid Build Coastguard Worker shape = '1' 8579*da0073e9SAndroid Build Coastguard Worker else: 8580*da0073e9SAndroid Build Coastguard Worker shape = '[' + ('1,' * 4) + ']' 8581*da0073e9SAndroid Build Coastguard Worker for _ in range(1, input_dims): 8582*da0073e9SAndroid Build Coastguard Worker shape = '[' + ",".join([shape] * 4) + ']' 8583*da0073e9SAndroid Build Coastguard Worker 8584*da0073e9SAndroid Build Coastguard Worker template = dedent(''' 8585*da0073e9SAndroid Build Coastguard Worker def func(): 8586*da0073e9SAndroid Build Coastguard Worker arg1 = {} 8587*da0073e9SAndroid Build Coastguard Worker arg2 = {} 8588*da0073e9SAndroid Build Coastguard Worker return torch.{}(arg1, arg2) 8589*da0073e9SAndroid Build Coastguard Worker ''') 8590*da0073e9SAndroid Build Coastguard Worker 8591*da0073e9SAndroid Build Coastguard Worker args = [] 8592*da0073e9SAndroid Build Coastguard Worker for dtype in dtypes: 8593*da0073e9SAndroid Build Coastguard Worker args = args + [f"torch.tensor({shape}, dtype={dtype})"] 8594*da0073e9SAndroid Build Coastguard Worker args = args + [1, 1.5] 8595*da0073e9SAndroid Build Coastguard Worker 8596*da0073e9SAndroid Build Coastguard Worker def isBool(arg): 8597*da0073e9SAndroid Build Coastguard Worker return type(arg) == bool or (type(arg) == str and "torch.bool" in arg) 8598*da0073e9SAndroid Build Coastguard Worker 8599*da0073e9SAndroid Build Coastguard Worker for op in ops: 8600*da0073e9SAndroid Build Coastguard Worker for first_arg in args: 8601*da0073e9SAndroid Build Coastguard Worker for second_arg in args: 8602*da0073e9SAndroid Build Coastguard Worker # subtract not supported for bool 8603*da0073e9SAndroid Build Coastguard Worker if (op == 'sub' or op == 'div') and (isBool(first_arg) or isBool(second_arg)): 8604*da0073e9SAndroid Build Coastguard Worker continue 8605*da0073e9SAndroid Build Coastguard Worker # div is not implemented 
correctly for mixed-type or int params 8606*da0073e9SAndroid Build Coastguard Worker if (op == 'div' and (type(first_arg) != type(second_arg) or 8607*da0073e9SAndroid Build Coastguard Worker isinstance(first_arg, int) or 8608*da0073e9SAndroid Build Coastguard Worker (isinstance(first_arg, str) and 'int' in first_arg))): 8609*da0073e9SAndroid Build Coastguard Worker continue 8610*da0073e9SAndroid Build Coastguard Worker return_line = f"torch.{op}({first_arg}, {second_arg})" 8611*da0073e9SAndroid Build Coastguard Worker # uncomment for debugging a failed test: 8612*da0073e9SAndroid Build Coastguard Worker # print("testing {}".format(return_line)) 8613*da0073e9SAndroid Build Coastguard Worker code = template.format(first_arg, second_arg, op) 8614*da0073e9SAndroid Build Coastguard Worker scope = {} 8615*da0073e9SAndroid Build Coastguard Worker exec(code, globals(), scope) 8616*da0073e9SAndroid Build Coastguard Worker non_jit_result = scope['func']() 8617*da0073e9SAndroid Build Coastguard Worker 8618*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 8619*da0073e9SAndroid Build Coastguard Worker graph = cu.func.graph 8620*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_complete_shape_analysis(graph, (), False) 8621*da0073e9SAndroid Build Coastguard Worker # use dim=-1 to represent a python/jit scalar. 8622*da0073e9SAndroid Build Coastguard Worker dim = -1 if type(first_arg) != str and type(second_arg) != str else non_jit_result.dim() 8623*da0073e9SAndroid Build Coastguard Worker dtype = non_jit_result.dtype 8624*da0073e9SAndroid Build Coastguard Worker # jit only supports int/float scalars. 
8625*da0073e9SAndroid Build Coastguard Worker if dim < 0: 8626*da0073e9SAndroid Build Coastguard Worker if dtype == torch.int64: 8627*da0073e9SAndroid Build Coastguard Worker dtype = torch.int32 8628*da0073e9SAndroid Build Coastguard Worker if dtype == torch.float64: 8629*da0073e9SAndroid Build Coastguard Worker dtype = torch.float32 8630*da0073e9SAndroid Build Coastguard Worker expect = self._dtype_to_expect(dtype, dim) 8631*da0073e9SAndroid Build Coastguard Worker jit_output = next(graph.outputs()) 8632*da0073e9SAndroid Build Coastguard Worker 8633*da0073e9SAndroid Build Coastguard Worker check = FileCheck() 8634*da0073e9SAndroid Build Coastguard Worker check.check(expect).run(str(jit_output)) 8635*da0073e9SAndroid Build Coastguard Worker 8636*da0073e9SAndroid Build Coastguard Worker def test_binary_op_shape(self): 8637*da0073e9SAndroid Build Coastguard Worker self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 0) 8638*da0073e9SAndroid Build Coastguard Worker self._test_binary_op_shape(['mul', 'div', 'add', 'sub'], 3) 8639*da0073e9SAndroid Build Coastguard Worker 8640*da0073e9SAndroid Build Coastguard Worker def test_no_dtype_shape(self): 8641*da0073e9SAndroid Build Coastguard Worker 8642*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8643*da0073e9SAndroid Build Coastguard Worker def foo(x): 8644*da0073e9SAndroid Build Coastguard Worker scalar_number = x.item() 8645*da0073e9SAndroid Build Coastguard Worker return x.add(scalar_number) 8646*da0073e9SAndroid Build Coastguard Worker 8647*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 8648*da0073e9SAndroid Build Coastguard Worker def foo2(x): 8649*da0073e9SAndroid Build Coastguard Worker scalar_number = x.item() 8650*da0073e9SAndroid Build Coastguard Worker return torch.tensor(1).add(scalar_number) 8651*da0073e9SAndroid Build Coastguard Worker 8652*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(5) 8653*da0073e9SAndroid Build Coastguard Worker g = foo.graph_for(t) 
8654*da0073e9SAndroid Build Coastguard Worker type = next(g.outputs()) 8655*da0073e9SAndroid Build Coastguard Worker self.assertTrue(type.type() == torch._C.TensorType.get()) 8656*da0073e9SAndroid Build Coastguard Worker g2 = foo2.graph_for(t) 8657*da0073e9SAndroid Build Coastguard Worker type = next(g.outputs()) 8658*da0073e9SAndroid Build Coastguard Worker self.assertTrue(type.type() == torch._C.TensorType.get()) 8659*da0073e9SAndroid Build Coastguard Worker 8660*da0073e9SAndroid Build Coastguard Worker 8661*da0073e9SAndroid Build Coastguard Worker def test_filecheck_parse(self): 8662*da0073e9SAndroid Build Coastguard Worker def test_check(): 8663*da0073e9SAndroid Build Coastguard Worker file = """ 8664*da0073e9SAndroid Build Coastguard Worker # CHECK: 2 8665*da0073e9SAndroid Build Coastguard Worker # CHECK: 3 8666*da0073e9SAndroid Build Coastguard Worker # CHECK: 2 8667*da0073e9SAndroid Build Coastguard Worker 232 8668*da0073e9SAndroid Build Coastguard Worker """ 8669*da0073e9SAndroid Build Coastguard Worker FileCheck().run(checks_file=file, test_file=file) 8670*da0073e9SAndroid Build Coastguard Worker file = """ 8671*da0073e9SAndroid Build Coastguard Worker # CHECK: 232 8672*da0073e9SAndroid Build Coastguard Worker 232 8673*da0073e9SAndroid Build Coastguard Worker """ 8674*da0073e9SAndroid Build Coastguard Worker FileCheck().run(file, "232") 8675*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected to find "232"'): 8676*da0073e9SAndroid Build Coastguard Worker FileCheck().run(file, "22") 8677*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'): 8678*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK: 22", "23") 8679*da0073e9SAndroid Build Coastguard Worker test_check() 8680*da0073e9SAndroid Build Coastguard Worker 8681*da0073e9SAndroid Build Coastguard Worker def test_check_count(): 8682*da0073e9SAndroid Build Coastguard Worker file = "22222" 
8683*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK-COUNT-5: 2", file) 8684*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK-COUNT-EXACTLY-5: 2", file) 8685*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK-COUNT-2: 22", file) 8686*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK-COUNT-1: 222", file) 8687*da0073e9SAndroid Build Coastguard Worker 8688*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected to not find'): 8689*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK-COUNT-EXACTLY-2: 2", file) 8690*da0073e9SAndroid Build Coastguard Worker test_check_count() 8691*da0073e9SAndroid Build Coastguard Worker 8692*da0073e9SAndroid Build Coastguard Worker def test_check_same(): 8693*da0073e9SAndroid Build Coastguard Worker file = "22\n33" 8694*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK-SAME: 22", file) 8695*da0073e9SAndroid Build Coastguard Worker 8696*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected to not find"): 8697*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK-SAME: 33", file) 8698*da0073e9SAndroid Build Coastguard Worker 8699*da0073e9SAndroid Build Coastguard Worker file = "22 1 3" 8700*da0073e9SAndroid Build Coastguard Worker 8701*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK: 2\n # CHECK-SAME: 3", file) 8702*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK-COUNT-2: 2\n # CHECK-SAME: 3", file) 8703*da0073e9SAndroid Build Coastguard Worker test_check_same() 8704*da0073e9SAndroid Build Coastguard Worker 8705*da0073e9SAndroid Build Coastguard Worker def test_bad_input(): 8706*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Check for bad input"): 8707*da0073e9SAndroid Build Coastguard Worker FileCheck().run("", "1") 8708*da0073e9SAndroid Build Coastguard Worker 8709*da0073e9SAndroid 
Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Could not parse check"): 8710*da0073e9SAndroid Build Coastguard Worker FileCheck().run("# CHECK1", "") 8711*da0073e9SAndroid Build Coastguard Worker 8712*da0073e9SAndroid Build Coastguard Worker test_bad_input() 8713*da0073e9SAndroid Build Coastguard Worker 8714*da0073e9SAndroid Build Coastguard Worker def test_script_module_call_noscript(self): 8715*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 8716*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8717*da0073e9SAndroid Build Coastguard Worker super().__init__() 8718*da0073e9SAndroid Build Coastguard Worker self.value = 1 8719*da0073e9SAndroid Build Coastguard Worker 8720*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 8721*da0073e9SAndroid Build Coastguard Worker def foo(self): 8722*da0073e9SAndroid Build Coastguard Worker return torch.ones(2, 2) + self.value 8723*da0073e9SAndroid Build Coastguard Worker 8724*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8725*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 8726*da0073e9SAndroid Build Coastguard Worker return input + self.foo() 8727*da0073e9SAndroid Build Coastguard Worker 8728*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 8729*da0073e9SAndroid Build Coastguard Worker m = M() 8730*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2) 8731*da0073e9SAndroid Build Coastguard Worker o = m(input) 8732*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o, input + torch.ones(2, 2) + 1) 8733*da0073e9SAndroid Build Coastguard Worker # check that we can change python attributes 8734*da0073e9SAndroid Build Coastguard Worker # and that those changes are picked up in script methods 8735*da0073e9SAndroid Build Coastguard Worker m.value = 2 8736*da0073e9SAndroid Build Coastguard Worker o = m(input) 8737*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(o, input + torch.ones(2, 2) + 2) 8738*da0073e9SAndroid Build Coastguard Worker 8739*da0073e9SAndroid Build Coastguard Worker def test_script_module_nochange_submodule(self): 8740*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 8741*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8742*da0073e9SAndroid Build Coastguard Worker super().__init__() 8743*da0073e9SAndroid Build Coastguard Worker self.sub = nn.Linear(5, 5) 8744*da0073e9SAndroid Build Coastguard Worker 8745*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8746*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 8747*da0073e9SAndroid Build Coastguard Worker return self.sub(input) 8748*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 8749*da0073e9SAndroid Build Coastguard Worker m = M() 8750*da0073e9SAndroid Build Coastguard Worker input = torch.randn(1, 5, 5) 8751*da0073e9SAndroid Build Coastguard Worker o = m(input) 8752*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o, m.sub(input)) 8753*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot re-assign"): 8754*da0073e9SAndroid Build Coastguard Worker m.sub = nn.Linear(5, 5) 8755*da0073e9SAndroid Build Coastguard Worker 8756*da0073e9SAndroid Build Coastguard Worker def test_module_apis(self): 8757*da0073e9SAndroid Build Coastguard Worker class Sub(torch.nn.Module): 8758*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 8759*da0073e9SAndroid Build Coastguard Worker return thing - 2 8760*da0073e9SAndroid Build Coastguard Worker 8761*da0073e9SAndroid Build Coastguard Worker class Double(torch.nn.Module): 8762*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 8763*da0073e9SAndroid Build Coastguard Worker return thing * 2 8764*da0073e9SAndroid Build Coastguard Worker 8765*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 
8766*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8767*da0073e9SAndroid Build Coastguard Worker super().__init__() 8768*da0073e9SAndroid Build Coastguard Worker self.mod = (Sub()) 8769*da0073e9SAndroid Build Coastguard Worker self.mod2 = (Sub()) 8770*da0073e9SAndroid Build Coastguard Worker self.mod3 = nn.Sequential(nn.Sequential(Sub())) 8771*da0073e9SAndroid Build Coastguard Worker self.mod4 = nn.Sequential(Sub(), Double()) 8772*da0073e9SAndroid Build Coastguard Worker 8773*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 8774*da0073e9SAndroid Build Coastguard Worker def method(self, x, x1, y, y1): 8775*da0073e9SAndroid Build Coastguard Worker mod_names = "" 8776*da0073e9SAndroid Build Coastguard Worker for name, mod in self.named_modules(): 8777*da0073e9SAndroid Build Coastguard Worker mod_names = mod_names + " " + name 8778*da0073e9SAndroid Build Coastguard Worker x = mod(x) 8779*da0073e9SAndroid Build Coastguard Worker 8780*da0073e9SAndroid Build Coastguard Worker children_names = "" 8781*da0073e9SAndroid Build Coastguard Worker for name, mod in self.named_children(): 8782*da0073e9SAndroid Build Coastguard Worker children_names = children_names + " " + name 8783*da0073e9SAndroid Build Coastguard Worker x1 = mod(x1) 8784*da0073e9SAndroid Build Coastguard Worker 8785*da0073e9SAndroid Build Coastguard Worker for mod in self.modules(): 8786*da0073e9SAndroid Build Coastguard Worker y = mod(y) 8787*da0073e9SAndroid Build Coastguard Worker 8788*da0073e9SAndroid Build Coastguard Worker for mod in self.children(): 8789*da0073e9SAndroid Build Coastguard Worker y1 = mod(y1) 8790*da0073e9SAndroid Build Coastguard Worker 8791*da0073e9SAndroid Build Coastguard Worker return mod_names, children_names, x, x1, y, y1 8792*da0073e9SAndroid Build Coastguard Worker 8793*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 8794*da0073e9SAndroid Build Coastguard Worker return x + 2 8795*da0073e9SAndroid Build Coastguard Worker 
8796*da0073e9SAndroid Build Coastguard Worker mod = torch.jit.script(MyMod()) 8797*da0073e9SAndroid Build Coastguard Worker inps = tuple([torch.tensor(i) for i in range(1, 5)]) 8798*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mod.method(*inps), MyMod().method(*inps)) 8799*da0073e9SAndroid Build Coastguard Worker 8800*da0073e9SAndroid Build Coastguard Worker def test_script_module_const(self): 8801*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 8802*da0073e9SAndroid Build Coastguard Worker 8803*da0073e9SAndroid Build Coastguard Worker __constants__ = ['b', 'i', 'c', 's'] 8804*da0073e9SAndroid Build Coastguard Worker 8805*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8806*da0073e9SAndroid Build Coastguard Worker super().__init__() 8807*da0073e9SAndroid Build Coastguard Worker self.b = False 8808*da0073e9SAndroid Build Coastguard Worker self.i = 1 8809*da0073e9SAndroid Build Coastguard Worker self.c = 3.5 8810*da0073e9SAndroid Build Coastguard Worker self.s = ["hello"] 8811*da0073e9SAndroid Build Coastguard Worker 8812*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8813*da0073e9SAndroid Build Coastguard Worker def forward(self): 8814*da0073e9SAndroid Build Coastguard Worker return self.b, self.i, self.c 8815*da0073e9SAndroid Build Coastguard Worker 8816*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 8817*da0073e9SAndroid Build Coastguard Worker m = M() 8818*da0073e9SAndroid Build Coastguard Worker o0, o1, o2 = m() 8819*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o0, 0) 8820*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o1, 1) 8821*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o2, 3.5) 8822*da0073e9SAndroid Build Coastguard Worker 8823*da0073e9SAndroid Build Coastguard Worker def test_script_module_fail_exist(self): 8824*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 
8825*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8826*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 8827*da0073e9SAndroid Build Coastguard Worker return x + self.whatisgoingon 8828*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Module 'M' has no attribute"): 8829*da0073e9SAndroid Build Coastguard Worker M() 8830*da0073e9SAndroid Build Coastguard Worker 8831*da0073e9SAndroid Build Coastguard Worker @unittest.skip("[module dedupe] currently NoneType refinement on optional attributes doesn't work.") 8832*da0073e9SAndroid Build Coastguard Worker def test_script_module_none_exist_fail(self): 8833*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 8834*da0073e9SAndroid Build Coastguard Worker def __init__(self, my_optional): 8835*da0073e9SAndroid Build Coastguard Worker super().__init__() 8836*da0073e9SAndroid Build Coastguard Worker self.my_optional = my_optional 8837*da0073e9SAndroid Build Coastguard Worker 8838*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8839*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 8840*da0073e9SAndroid Build Coastguard Worker if self.my_optional is not None: 8841*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) + self.my_optional 8842*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 8843*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "has no attribute 'my_optional'"): 8844*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 8845*da0073e9SAndroid Build Coastguard Worker fb = M(None) 8846*da0073e9SAndroid Build Coastguard Worker fb(x) 8847*da0073e9SAndroid Build Coastguard Worker 8848*da0073e9SAndroid Build Coastguard Worker def test_script_module_invalid_consts(self): 8849*da0073e9SAndroid Build Coastguard Worker class Foo(torch.jit.ScriptModule): 8850*da0073e9SAndroid Build Coastguard Worker __constants__ = ['invalid'] 8851*da0073e9SAndroid 
Build Coastguard Worker 8852*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8853*da0073e9SAndroid Build Coastguard Worker super().__init__() 8854*da0073e9SAndroid Build Coastguard Worker self.invalid = [nn.Linear(3, 4)] 8855*da0073e9SAndroid Build Coastguard Worker 8856*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 8857*da0073e9SAndroid Build Coastguard Worker TypeError, 8858*da0073e9SAndroid Build Coastguard Worker "Linear' object in attribute 'Foo.invalid' is not a valid constant"): 8859*da0073e9SAndroid Build Coastguard Worker Foo() 8860*da0073e9SAndroid Build Coastguard Worker 8861*da0073e9SAndroid Build Coastguard Worker class Foo2(torch.jit.ScriptModule): 8862*da0073e9SAndroid Build Coastguard Worker __constants__ = ['invalid'] 8863*da0073e9SAndroid Build Coastguard Worker 8864*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8865*da0073e9SAndroid Build Coastguard Worker super().__init__() 8866*da0073e9SAndroid Build Coastguard Worker self.invalid = int 8867*da0073e9SAndroid Build Coastguard Worker 8868*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "not a valid constant"): 8869*da0073e9SAndroid Build Coastguard Worker Foo2() 8870*da0073e9SAndroid Build Coastguard Worker 8871*da0073e9SAndroid Build Coastguard Worker class Foo3(torch.jit.ScriptModule): 8872*da0073e9SAndroid Build Coastguard Worker __constants__ = ['invalid'] 8873*da0073e9SAndroid Build Coastguard Worker 8874*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8875*da0073e9SAndroid Build Coastguard Worker super().__init__() 8876*da0073e9SAndroid Build Coastguard Worker self.invalid = (3, 4, {}) 8877*da0073e9SAndroid Build Coastguard Worker 8878*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "not a valid constant"): 8879*da0073e9SAndroid Build Coastguard Worker Foo3() 8880*da0073e9SAndroid Build Coastguard Worker 8881*da0073e9SAndroid Build 
Coastguard Worker class Foo4(torch.jit.ScriptModule): 8882*da0073e9SAndroid Build Coastguard Worker __constants__ = ['invalid'] 8883*da0073e9SAndroid Build Coastguard Worker 8884*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8885*da0073e9SAndroid Build Coastguard Worker super().__init__() 8886*da0073e9SAndroid Build Coastguard Worker self.invalid = np.int64(5) 8887*da0073e9SAndroid Build Coastguard Worker 8888*da0073e9SAndroid Build Coastguard Worker # verify that we capture human understandable class name 8889*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "numpy.int64"): 8890*da0073e9SAndroid Build Coastguard Worker Foo4() 8891*da0073e9SAndroid Build Coastguard Worker 8892*da0073e9SAndroid Build Coastguard Worker def test_script_module_param_buffer_mutation(self): 8893*da0073e9SAndroid Build Coastguard Worker # TODO: add param mutation test case after JIT support it 8894*da0073e9SAndroid Build Coastguard Worker class ModuleBufferMutate(torch.jit.ScriptModule): 8895*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8896*da0073e9SAndroid Build Coastguard Worker super().__init__() 8897*da0073e9SAndroid Build Coastguard Worker self.running_var = nn.Buffer(torch.tensor(0, dtype=torch.long)) 8898*da0073e9SAndroid Build Coastguard Worker 8899*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8900*da0073e9SAndroid Build Coastguard Worker def forward(self): 8901*da0073e9SAndroid Build Coastguard Worker if self.training: 8902*da0073e9SAndroid Build Coastguard Worker self.running_var += 1 8903*da0073e9SAndroid Build Coastguard Worker return self.running_var 8904*da0073e9SAndroid Build Coastguard Worker 8905*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 8906*da0073e9SAndroid Build Coastguard Worker m = ModuleBufferMutate() 8907*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(), 1) 8908*da0073e9SAndroid Build Coastguard Worker m.eval() 
8909*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(), 1) 8910*da0073e9SAndroid Build Coastguard Worker 8911*da0073e9SAndroid Build Coastguard Worker def test_script_module_for(self): 8912*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 8913*da0073e9SAndroid Build Coastguard Worker __constants__ = ['b'] 8914*da0073e9SAndroid Build Coastguard Worker 8915*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8916*da0073e9SAndroid Build Coastguard Worker super().__init__() 8917*da0073e9SAndroid Build Coastguard Worker self.b = [1, 2, 3, 4] 8918*da0073e9SAndroid Build Coastguard Worker 8919*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8920*da0073e9SAndroid Build Coastguard Worker def forward(self): 8921*da0073e9SAndroid Build Coastguard Worker sum = 0 8922*da0073e9SAndroid Build Coastguard Worker for i in self.b: 8923*da0073e9SAndroid Build Coastguard Worker sum += i 8924*da0073e9SAndroid Build Coastguard Worker return sum 8925*da0073e9SAndroid Build Coastguard Worker 8926*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 8927*da0073e9SAndroid Build Coastguard Worker m = M() 8928*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(), 10) 8929*da0073e9SAndroid Build Coastguard Worker 8930*da0073e9SAndroid Build Coastguard Worker def test_override_magic(self): 8931*da0073e9SAndroid Build Coastguard Worker class OverrideMagic(nn.Module): 8932*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 8933*da0073e9SAndroid Build Coastguard Worker def __len__(self): 8934*da0073e9SAndroid Build Coastguard Worker return 10 8935*da0073e9SAndroid Build Coastguard Worker 8936*da0073e9SAndroid Build Coastguard Worker mod = OverrideMagic() 8937*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(mod), len(torch.jit.script(mod))) 8938*da0073e9SAndroid Build Coastguard Worker 8939*da0073e9SAndroid Build Coastguard Worker class 
OverrideMagicSeq(nn.Sequential): 8940*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 8941*da0073e9SAndroid Build Coastguard Worker def __len__(self): 8942*da0073e9SAndroid Build Coastguard Worker return 10 8943*da0073e9SAndroid Build Coastguard Worker 8944*da0073e9SAndroid Build Coastguard Worker mod = OverrideMagicSeq() 8945*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(mod), len(torch.jit.script(mod))) 8946*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.jit.script(mod)) 8947*da0073e9SAndroid Build Coastguard Worker 8948*da0073e9SAndroid Build Coastguard Worker def test_script_module_for2(self): 8949*da0073e9SAndroid Build Coastguard Worker class Sub(torch.jit.ScriptModule): 8950*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8951*da0073e9SAndroid Build Coastguard Worker super().__init__() 8952*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(2)) 8953*da0073e9SAndroid Build Coastguard Worker 8954*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8955*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 8956*da0073e9SAndroid Build Coastguard Worker return self.weight + thing 8957*da0073e9SAndroid Build Coastguard Worker 8958*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 8959*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8960*da0073e9SAndroid Build Coastguard Worker super().__init__() 8961*da0073e9SAndroid Build Coastguard Worker self.mods = nn.ModuleList([Sub() for i in range(10)]) 8962*da0073e9SAndroid Build Coastguard Worker 8963*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 8964*da0073e9SAndroid Build Coastguard Worker def forward(self, v): 8965*da0073e9SAndroid Build Coastguard Worker for m in self.mods: 8966*da0073e9SAndroid Build Coastguard Worker v = m(v) 8967*da0073e9SAndroid Build Coastguard Worker return v 8968*da0073e9SAndroid Build Coastguard Worker 
8969*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 8970*da0073e9SAndroid Build Coastguard Worker i = torch.empty(2) 8971*da0073e9SAndroid Build Coastguard Worker m = M() 8972*da0073e9SAndroid Build Coastguard Worker o = m(i) 8973*da0073e9SAndroid Build Coastguard Worker v = i 8974*da0073e9SAndroid Build Coastguard Worker for sub in m.mods: 8975*da0073e9SAndroid Build Coastguard Worker v = sub(v) 8976*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o, v) 8977*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "object is not iterable"): 8978*da0073e9SAndroid Build Coastguard Worker print(list(m)) 8979*da0073e9SAndroid Build Coastguard Worker 8980*da0073e9SAndroid Build Coastguard Worker def test_attr_qscheme_script(self): 8981*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 8982*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 8983*da0073e9SAndroid Build Coastguard Worker super().__init__() 8984*da0073e9SAndroid Build Coastguard Worker self.qscheme = torch.per_tensor_affine 8985*da0073e9SAndroid Build Coastguard Worker 8986*da0073e9SAndroid Build Coastguard Worker def forward(self): 8987*da0073e9SAndroid Build Coastguard Worker if self.qscheme == torch.per_tensor_symmetric: 8988*da0073e9SAndroid Build Coastguard Worker return 3 8989*da0073e9SAndroid Build Coastguard Worker else: 8990*da0073e9SAndroid Build Coastguard Worker return 4 8991*da0073e9SAndroid Build Coastguard Worker 8992*da0073e9SAndroid Build Coastguard Worker f = Foo() 8993*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(f) 8994*da0073e9SAndroid Build Coastguard Worker self.assertEqual(f(), scripted()) 8995*da0073e9SAndroid Build Coastguard Worker 8996*da0073e9SAndroid Build Coastguard Worker def test_script_module_const_submodule_fail(self): 8997*da0073e9SAndroid Build Coastguard Worker class Sub(torch.jit.ScriptModule): 8998*da0073e9SAndroid Build Coastguard Worker 
def __init__(self) -> None: 8999*da0073e9SAndroid Build Coastguard Worker super().__init__() 9000*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(2)) 9001*da0073e9SAndroid Build Coastguard Worker 9002*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9003*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 9004*da0073e9SAndroid Build Coastguard Worker return self.weight + thing 9005*da0073e9SAndroid Build Coastguard Worker 9006*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 9007*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9008*da0073e9SAndroid Build Coastguard Worker super().__init__() 9009*da0073e9SAndroid Build Coastguard Worker self.mods = [Sub() for _ in range(10)] 9010*da0073e9SAndroid Build Coastguard Worker 9011*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9012*da0073e9SAndroid Build Coastguard Worker def forward(self): 9013*da0073e9SAndroid Build Coastguard Worker for _ in self.mods: 9014*da0073e9SAndroid Build Coastguard Worker print(1) 9015*da0073e9SAndroid Build Coastguard Worker return 4 9016*da0073e9SAndroid Build Coastguard Worker 9017*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "has no attribute 'mods'"): 9018*da0073e9SAndroid Build Coastguard Worker M() 9019*da0073e9SAndroid Build Coastguard Worker 9020*da0073e9SAndroid Build Coastguard Worker class DerivedStateModule(torch.jit.ScriptModule): 9021*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9022*da0073e9SAndroid Build Coastguard Worker super(TestScript.DerivedStateModule, self).__init__() 9023*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float)) 9024*da0073e9SAndroid Build Coastguard Worker self.derived = nn.Buffer(torch.neg(self.param).detach().clone()) 9025*da0073e9SAndroid Build Coastguard Worker 9026*da0073e9SAndroid Build Coastguard Worker # 
This is a flag so we can test that the pack method was called 9027*da0073e9SAndroid Build Coastguard Worker self.pack_called = nn.Buffer(torch.zeros(1, dtype=torch.long)) 9028*da0073e9SAndroid Build Coastguard Worker # This is a flag so we can test that the unpack method was called 9029*da0073e9SAndroid Build Coastguard Worker self.unpack_called = nn.Buffer(torch.zeros(1, dtype=torch.long)) 9030*da0073e9SAndroid Build Coastguard Worker 9031*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9032*da0073e9SAndroid Build Coastguard Worker def _pack(self): 9033*da0073e9SAndroid Build Coastguard Worker self.pack_called.set_(torch.ones(1, dtype=torch.long)) 9034*da0073e9SAndroid Build Coastguard Worker self.derived.set_(torch.rand(1).detach()) 9035*da0073e9SAndroid Build Coastguard Worker 9036*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9037*da0073e9SAndroid Build Coastguard Worker def _unpack(self): 9038*da0073e9SAndroid Build Coastguard Worker self.unpack_called.set_(torch.ones(1, dtype=torch.long)) 9039*da0073e9SAndroid Build Coastguard Worker self.derived.set_(torch.neg(self.param).detach()) 9040*da0073e9SAndroid Build Coastguard Worker 9041*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9042*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 9043*da0073e9SAndroid Build Coastguard Worker return x + self.derived 9044*da0073e9SAndroid Build Coastguard Worker 9045*da0073e9SAndroid Build Coastguard Worker def test_pack_unpack_state(self): 9046*da0073e9SAndroid Build Coastguard Worker sm = TestScript.DerivedStateModule() 9047*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 9048*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) 9049*da0073e9SAndroid Build Coastguard Worker 9050*da0073e9SAndroid Build Coastguard Worker # Test save path 9051*da0073e9SAndroid Build Coastguard Worker 
self.assertFalse(sm.pack_called.item()) 9052*da0073e9SAndroid Build Coastguard Worker self.assertFalse(sm.unpack_called.item()) 9053*da0073e9SAndroid Build Coastguard Worker imported = self.getExportImportCopyWithPacking(sm) 9054*da0073e9SAndroid Build Coastguard Worker # ensure pack was called before serialization 9055*da0073e9SAndroid Build Coastguard Worker self.assertTrue(sm.pack_called.item()) 9056*da0073e9SAndroid Build Coastguard Worker # ensure unpack was called after serialization so as to leave the module in an initialized state 9057*da0073e9SAndroid Build Coastguard Worker self.assertTrue(sm.unpack_called.item()) 9058*da0073e9SAndroid Build Coastguard Worker 9059*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(sm.derived, torch.neg(sm.param)) 9060*da0073e9SAndroid Build Coastguard Worker 9061*da0073e9SAndroid Build Coastguard Worker # Test load paths 9062*da0073e9SAndroid Build Coastguard Worker self.assertTrue(imported.unpack_called.item()) 9063*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float))) 9064*da0073e9SAndroid Build Coastguard Worker 9065*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support") 9066*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(True, "Skipping while landing PR stack") 9067*da0073e9SAndroid Build Coastguard Worker def test_torch_functional(self): 9068*da0073e9SAndroid Build Coastguard Worker def stft(input, n_fft): 9069*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, int) -> Tensor 9070*da0073e9SAndroid Build Coastguard Worker return torch.stft(input, n_fft, return_complex=True) 9071*da0073e9SAndroid Build Coastguard Worker 9072*da0073e9SAndroid Build Coastguard Worker inps = (torch.randn(10), 7) 9073*da0073e9SAndroid Build Coastguard Worker self.assertEqual(stft(*inps), torch.jit.script(stft)(*inps)) 9074*da0073e9SAndroid Build Coastguard Worker 
9075*da0073e9SAndroid Build Coastguard Worker def istft(input, n_fft): 9076*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, int) -> Tensor 9077*da0073e9SAndroid Build Coastguard Worker return torch.istft(input, n_fft) 9078*da0073e9SAndroid Build Coastguard Worker 9079*da0073e9SAndroid Build Coastguard Worker inps2 = (stft(*inps), inps[1]) 9080*da0073e9SAndroid Build Coastguard Worker self.assertEqual(istft(*inps2), torch.jit.script(istft)(*inps2)) 9081*da0073e9SAndroid Build Coastguard Worker 9082*da0073e9SAndroid Build Coastguard Worker def lu_unpack(x): 9083*da0073e9SAndroid Build Coastguard Worker A_LU, pivots = torch.linalg.lu_factor(x) 9084*da0073e9SAndroid Build Coastguard Worker return torch.lu_unpack(A_LU, pivots) 9085*da0073e9SAndroid Build Coastguard Worker 9086*da0073e9SAndroid Build Coastguard Worker for shape in ((3, 3), (5, 3, 3), (7, 3, 5, 5), (7, 5, 3, 3, 3)): 9087*da0073e9SAndroid Build Coastguard Worker a = torch.randn(*shape) 9088*da0073e9SAndroid Build Coastguard Worker self.checkScript(lu_unpack, (a,)) 9089*da0073e9SAndroid Build Coastguard Worker 9090*da0073e9SAndroid Build Coastguard Worker def cdist_fn(): 9091*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]]) 9092*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]]) 9093*da0073e9SAndroid Build Coastguard Worker return torch.cdist(a, b, compute_mode="use_mm_for_euclid_dist") 9094*da0073e9SAndroid Build Coastguard Worker 9095*da0073e9SAndroid Build Coastguard Worker self.checkScript(cdist_fn, ()) 9096*da0073e9SAndroid Build Coastguard Worker 9097*da0073e9SAndroid Build Coastguard Worker def norm(): 9098*da0073e9SAndroid Build Coastguard Worker c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float) 9099*da0073e9SAndroid Build Coastguard Worker return torch.norm(c, p="fro"), torch.norm(c, p="nuc"), torch.norm(c), torch.norm(c, p=.5) 9100*da0073e9SAndroid Build 
Coastguard Worker 9101*da0073e9SAndroid Build Coastguard Worker self.checkScript(norm, ()) 9102*da0073e9SAndroid Build Coastguard Worker 9103*da0073e9SAndroid Build Coastguard Worker def torch_unique(dim: Optional[int]): 9104*da0073e9SAndroid Build Coastguard Worker ten = torch.unique(torch.tensor([[1, 3], [2, 3]], dtype=torch.long)) 9105*da0073e9SAndroid Build Coastguard Worker a = torch.unique(ten, dim=dim) 9106*da0073e9SAndroid Build Coastguard Worker b = torch.unique(ten, return_counts=True, dim=dim) 9107*da0073e9SAndroid Build Coastguard Worker c = torch.unique(ten, return_inverse=True, dim=dim) 9108*da0073e9SAndroid Build Coastguard Worker d = torch.unique(ten, return_counts=True, return_inverse=True, dim=dim) 9109*da0073e9SAndroid Build Coastguard Worker return a, b, c, d 9110*da0073e9SAndroid Build Coastguard Worker 9111*da0073e9SAndroid Build Coastguard Worker self.checkScript(torch_unique, (None,)) 9112*da0073e9SAndroid Build Coastguard Worker self.checkScript(torch_unique, (0,)) 9113*da0073e9SAndroid Build Coastguard Worker 9114*da0073e9SAndroid Build Coastguard Worker def torch_unique_consecutive(dim: Optional[int]): 9115*da0073e9SAndroid Build Coastguard Worker ten = torch.unique(torch.tensor([[1, 3], [3, 2], [3, 2], [2, 3]], dtype=torch.long)) 9116*da0073e9SAndroid Build Coastguard Worker a = torch.unique_consecutive(ten, dim=dim) 9117*da0073e9SAndroid Build Coastguard Worker b = torch.unique_consecutive(ten, return_counts=True, dim=dim) 9118*da0073e9SAndroid Build Coastguard Worker c = torch.unique_consecutive(ten, return_inverse=True, dim=dim) 9119*da0073e9SAndroid Build Coastguard Worker d = torch.unique_consecutive(ten, return_counts=True, return_inverse=True, dim=dim) 9120*da0073e9SAndroid Build Coastguard Worker return a, b, c, d 9121*da0073e9SAndroid Build Coastguard Worker 9122*da0073e9SAndroid Build Coastguard Worker self.checkScript(torch_unique_consecutive, (None,)) 9123*da0073e9SAndroid Build Coastguard Worker 
self.checkScript(torch_unique_consecutive, (0,)) 9124*da0073e9SAndroid Build Coastguard Worker 9125*da0073e9SAndroid Build Coastguard Worker def test_torch_functional_tensordot_int(self): 9126*da0073e9SAndroid Build Coastguard Worker def tensordot_dims_int(a: torch.Tensor, b: torch.Tensor, dims: int): 9127*da0073e9SAndroid Build Coastguard Worker return torch.tensordot(a, b, dims=dims) 9128*da0073e9SAndroid Build Coastguard Worker 9129*da0073e9SAndroid Build Coastguard Worker a = torch.arange(120.).reshape(2, 3, 4, 5) 9130*da0073e9SAndroid Build Coastguard Worker b = torch.arange(840.).reshape(4, 5, 6, 7) 9131*da0073e9SAndroid Build Coastguard Worker dims = 2 9132*da0073e9SAndroid Build Coastguard Worker self.checkScript(tensordot_dims_int, (a, b, dims)) 9133*da0073e9SAndroid Build Coastguard Worker 9134*da0073e9SAndroid Build Coastguard Worker for dims in [-1, 5]: 9135*da0073e9SAndroid Build Coastguard Worker try: 9136*da0073e9SAndroid Build Coastguard Worker tensordot_dims_int(a, b, dims) 9137*da0073e9SAndroid Build Coastguard Worker except RuntimeError as error: 9138*da0073e9SAndroid Build Coastguard Worker if dims < 0: 9139*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(error), "tensordot expects dims >= 0, but got dims=" + str(dims)) 9140*da0073e9SAndroid Build Coastguard Worker if dims > min(a.dim(), b.dim()): 9141*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(error), "tensordot expects dims < ndim_a or ndim_b, but got dims=" + str(dims)) 9142*da0073e9SAndroid Build Coastguard Worker 9143*da0073e9SAndroid Build Coastguard Worker def test_torch_functional_tensordot_tensor(self): 9144*da0073e9SAndroid Build Coastguard Worker def tensordot_dims_tensor(a: torch.Tensor, b: torch.Tensor, dims: torch.Tensor): 9145*da0073e9SAndroid Build Coastguard Worker return torch.tensordot(a, b, dims=dims) 9146*da0073e9SAndroid Build Coastguard Worker 9147*da0073e9SAndroid Build Coastguard Worker a = torch.arange(120.).reshape(2, 3, 4, 5) 
# Continuation of test_torch_functional_tensordot_tensor (its def line is on the previous source line).
        b = torch.arange(840.).reshape(4, 5, 6, 7)
        dims = torch.tensor([2])
        self.checkScript(tensordot_dims_tensor, (a, b, dims))

        # dims as a 2-D tensor: row 0 lists dims of `a`, row 1 the matching dims of `b`.
        a = torch.arange(60.).reshape(3, 4, 5)
        b = torch.arange(24.).reshape(4, 3, 2)
        dims = torch.tensor([[1, 0], [0, 1]], dtype=torch.long)
        self.checkScript(tensordot_dims_tensor, (a, b, dims))

    def test_torch_functional_tensordot_list(self):
        """Script tensordot with ``dims`` passed as ``List[List[int]]``."""
        def tensordot_dims_list(a: torch.Tensor, b: torch.Tensor, dims: List[List[int]]):
            return torch.tensordot(a, b, dims=dims)

        a = torch.arange(60.).reshape(3, 4, 5)
        b = torch.arange(24.).reshape(4, 3, 2)
        dims = [[1, 0], [0, 1]]
        self.checkScript(tensordot_dims_list, (a, b, dims))

    def test_torch_functional_tensordot_tuple(self):
        """Script tensordot with ``dims`` passed as ``Tuple[List[int], List[int]]``."""
        def tensordot_dims_tuple(a: torch.Tensor, b: torch.Tensor, dims: Tuple[List[int], List[int]]):
            return torch.tensordot(a, b, dims=dims)

        a = torch.arange(60.).reshape(3, 4, 5)
        b = torch.arange(24.).reshape(4, 3, 2)
        dims = ([1, 0], [0, 1])
        self.checkScript(tensordot_dims_tuple, (a, b, dims))

    def test_missing_getstate(self):
        """Scripting a module that exports ``__setstate__`` but has no ``__getstate__`` raises.

        The RuntimeError message is expected to mention "getstate".
        """
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = 1

            def forward(self, x):
                return x * self.x

            @torch.jit.export
            def __setstate__(self, state):
                self.x = state[0]
                self.training = state[1]

        with self.assertRaisesRegex(RuntimeError, "getstate"):
            scripted = torch.jit.script(Foo())

    def test_inlining_cleanup(self):
        """After running the inline pass, the scripted graph must contain no prim::If."""
        def foo(x):
            return F.linear(x, x)

        @torch.jit.script
        def fee(x):
            return foo(x)

        # inlining optimizations should have cleaned up linear if statement
# Continuation of test_inlining_cleanup (its def line is on the previous source line).
        self.run_pass("inline", fee.graph)
        FileCheck().check_not("prim::If").run(fee.graph)

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_pack_unpack_nested(self):
        """Module.apply must reach ``_pack``/``_unpack`` on every nested ScriptModule.

        Mod -> SubMod -> SubSubMod each add their own buffer (1, 2 and 3) to the
        input, so the unpacked output for a zero input is 6 everywhere. ``_pack``
        replaces each buffer with ``zeros(1)`` (broadcast on add), so the packed
        output is all zeros; ``_unpack`` restores the original buffers.
        """
        class SubSubMod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.buf = nn.Buffer(torch.ones(3, 4) * 3)

            @torch.jit.script_method
            def _pack(self):
                self.buf.set_(torch.zeros(1))

            @torch.jit.script_method
            def _unpack(self):
                self.buf.set_(torch.ones(3, 4) * 3)

            @torch.jit.script_method
            def forward(self, x):
                return x + self.buf

        class SubMod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.buf = nn.Buffer(torch.ones(3, 4) * 2)
                self.ssm = SubSubMod()

            @torch.jit.script_method
            def _pack(self):
                self.buf.set_(torch.zeros(1))

            @torch.jit.script_method
            def _unpack(self):
                self.buf.set_(torch.ones(3, 4) * 2)

            @torch.jit.script_method
            def forward(self, x):
                return self.ssm(x + self.buf)

        class Mod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()
                self.buf = nn.Buffer(torch.ones(3, 4) * 1)

            @torch.jit.script_method
            def _pack(self):
                self.buf.set_(torch.zeros(1))

            @torch.jit.script_method
            def _unpack(self):
                self.buf.set_(torch.ones(3, 4))

            @torch.jit.script_method
            def forward(self, x):
                return self.submod(x + self.buf)

        m = Mod()
        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
        # apply() invokes the lambda on the module and all of its submodules.
        m.apply(lambda s: s._pack())
        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.zeros(3, 4))
        m.apply(lambda s: s._unpack())
        torch.testing.assert_close(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)

    def test_torch_any(self):
        """torch.any with and without a ``dim`` argument, checked via checkScript (incl. empty and negative-dim cases)."""
        def fn(x):
            return torch.any(x)

        def fn1(x, dim: int):
            return torch.any(x, dim)

        self.checkScript(fn, (torch.randn(3, 4), ))
        self.checkScript(fn, (torch.empty(3), ))
        self.checkScript(fn, (torch.empty(1), ))
        self.checkScript(fn, (torch.ones(3, 4),))
        self.checkScript(fn, (torch.zeros(5, 7, 1),))
        self.checkScript(fn1, (torch.empty(3, 4), -2))
        self.checkScript(fn1, (torch.randn(3, 8), 1))
        self.checkScript(fn1, (torch.zeros(3, 6, 9), -3))
        self.checkScript(fn1, (torch.empty(5), 0))

    def test_any(self):
        """The builtin ``any()`` over int/float/bool/str lists, including empty lists, checked via checkScript."""
        def fn(x: List[int]):
            return any(x)
9286*da0073e9SAndroid Build Coastguard Worker 9287*da0073e9SAndroid Build Coastguard Worker def fn1(x: List[float]): 9288*da0073e9SAndroid Build Coastguard Worker return any(x) 9289*da0073e9SAndroid Build Coastguard Worker 9290*da0073e9SAndroid Build Coastguard Worker def fn2(x: List[bool]): 9291*da0073e9SAndroid Build Coastguard Worker return any(x) 9292*da0073e9SAndroid Build Coastguard Worker 9293*da0073e9SAndroid Build Coastguard Worker def fn3(x: List[str]): 9294*da0073e9SAndroid Build Coastguard Worker return any(x) 9295*da0073e9SAndroid Build Coastguard Worker 9296*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([0, 0, 0, 0], )) 9297*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([0, 3, 0], )) 9298*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([], )) 9299*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn1, ([1.0, 2.0, 3.0], )) 9300*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn1, ([0.0, 0.0, 0.0], )) 9301*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn1, ([0, 0, 0], )) 9302*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn1, ([], )) 9303*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn2, ([True, False, False], )) 9304*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn2, ([False, False, False], )) 9305*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn2, ([True, True, True, True], )) 9306*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn2, ([], )) 9307*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn3, (["", "", ""], )) 9308*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn3, (["", "", "", "-1"], )) 9309*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn3, ([], )) 9310*da0073e9SAndroid Build Coastguard Worker 9311*da0073e9SAndroid Build Coastguard Worker def test_script_module_not_tuple(self): 9312*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 
9313*da0073e9SAndroid Build Coastguard Worker __constants__ = ['mods'] 9314*da0073e9SAndroid Build Coastguard Worker 9315*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9316*da0073e9SAndroid Build Coastguard Worker super().__init__() 9317*da0073e9SAndroid Build Coastguard Worker self.mods = 1 9318*da0073e9SAndroid Build Coastguard Worker 9319*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9320*da0073e9SAndroid Build Coastguard Worker def forward(self, v): 9321*da0073e9SAndroid Build Coastguard Worker for m in self.mods: 9322*da0073e9SAndroid Build Coastguard Worker print(m) 9323*da0073e9SAndroid Build Coastguard Worker return v 9324*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): 9325*da0073e9SAndroid Build Coastguard Worker M() 9326*da0073e9SAndroid Build Coastguard Worker 9327*da0073e9SAndroid Build Coastguard Worker def test_attr_module_constants(self): 9328*da0073e9SAndroid Build Coastguard Worker class M2(torch.jit.ScriptModule): 9329*da0073e9SAndroid Build Coastguard Worker def __init__(self, mod_list): 9330*da0073e9SAndroid Build Coastguard Worker super().__init__() 9331*da0073e9SAndroid Build Coastguard Worker self.mods = mod_list 9332*da0073e9SAndroid Build Coastguard Worker 9333*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9334*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 9335*da0073e9SAndroid Build Coastguard Worker return self.mods.forward(x) 9336*da0073e9SAndroid Build Coastguard Worker 9337*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 9338*da0073e9SAndroid Build Coastguard Worker m = M2(nn.Sequential(nn.ReLU())) 9339*da0073e9SAndroid Build Coastguard Worker self.assertExportImportModule(m, (torch.randn(2, 2),)) 9340*da0073e9SAndroid Build Coastguard Worker 9341*da0073e9SAndroid Build Coastguard Worker def test_script_sequential_for(self): 9342*da0073e9SAndroid Build 
Coastguard Worker class Sub(torch.jit.ScriptModule): 9343*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9344*da0073e9SAndroid Build Coastguard Worker super().__init__() 9345*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(2)) 9346*da0073e9SAndroid Build Coastguard Worker 9347*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9348*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 9349*da0073e9SAndroid Build Coastguard Worker return self.weight + thing 9350*da0073e9SAndroid Build Coastguard Worker 9351*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 9352*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9353*da0073e9SAndroid Build Coastguard Worker super().__init__() 9354*da0073e9SAndroid Build Coastguard Worker self.mods = nn.Sequential(Sub(), Sub(), Sub()) 9355*da0073e9SAndroid Build Coastguard Worker 9356*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9357*da0073e9SAndroid Build Coastguard Worker def forward(self, v): 9358*da0073e9SAndroid Build Coastguard Worker for m in self.mods: 9359*da0073e9SAndroid Build Coastguard Worker v = m(v) 9360*da0073e9SAndroid Build Coastguard Worker return v 9361*da0073e9SAndroid Build Coastguard Worker 9362*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9363*da0073e9SAndroid Build Coastguard Worker def forward2(self, v): 9364*da0073e9SAndroid Build Coastguard Worker return self.mods(v) 9365*da0073e9SAndroid Build Coastguard Worker 9366*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 9367*da0073e9SAndroid Build Coastguard Worker i = torch.empty(2) 9368*da0073e9SAndroid Build Coastguard Worker m = M() 9369*da0073e9SAndroid Build Coastguard Worker o = m(i) 9370*da0073e9SAndroid Build Coastguard Worker v = i 9371*da0073e9SAndroid Build Coastguard Worker for sub in m.mods._modules.values(): 9372*da0073e9SAndroid Build Coastguard 
Worker v = sub(v) 9373*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o, v) 9374*da0073e9SAndroid Build Coastguard Worker 9375*da0073e9SAndroid Build Coastguard Worker o2 = m.forward2(i) 9376*da0073e9SAndroid Build Coastguard Worker self.assertEqual(o2, v) 9377*da0073e9SAndroid Build Coastguard Worker 9378*da0073e9SAndroid Build Coastguard Worker def test_script_sequential_sliced_iteration(self): 9379*da0073e9SAndroid Build Coastguard Worker class seq_mod(nn.Module): 9380*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9381*da0073e9SAndroid Build Coastguard Worker super().__init__() 9382*da0073e9SAndroid Build Coastguard Worker self.layers = [nn.ReLU(), nn.ReLU(), nn.ReLU()] 9383*da0073e9SAndroid Build Coastguard Worker self.layers = nn.Sequential(*self.layers) 9384*da0073e9SAndroid Build Coastguard Worker 9385*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 9386*da0073e9SAndroid Build Coastguard Worker x = self.layers[0].forward(input) 9387*da0073e9SAndroid Build Coastguard Worker for layer in self.layers[1:3]: 9388*da0073e9SAndroid Build Coastguard Worker x = layer.forward(x) 9389*da0073e9SAndroid Build Coastguard Worker for layer in self.layers[2:]: 9390*da0073e9SAndroid Build Coastguard Worker x = layer.forward(x) 9391*da0073e9SAndroid Build Coastguard Worker return x 9392*da0073e9SAndroid Build Coastguard Worker 9393*da0073e9SAndroid Build Coastguard Worker seq = seq_mod() 9394*da0073e9SAndroid Build Coastguard Worker self.checkModule(seq, [torch.tensor([-2, 1, -1, 2])]) 9395*da0073e9SAndroid Build Coastguard Worker 9396*da0073e9SAndroid Build Coastguard Worker def test_script_sequential_orderdict(self): 9397*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 9398*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9399*da0073e9SAndroid Build Coastguard Worker super().__init__() 9400*da0073e9SAndroid Build Coastguard Worker self.mods = nn.Sequential(OrderedDict([ 
9401*da0073e9SAndroid Build Coastguard Worker ("conv", nn.Conv2d(1, 20, 5)), 9402*da0073e9SAndroid Build Coastguard Worker ("relu", nn.ReLU()) 9403*da0073e9SAndroid Build Coastguard Worker ])) 9404*da0073e9SAndroid Build Coastguard Worker 9405*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9406*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 9407*da0073e9SAndroid Build Coastguard Worker return self.mods(input) 9408*da0073e9SAndroid Build Coastguard Worker 9409*da0073e9SAndroid Build Coastguard Worker m = M() 9410*da0073e9SAndroid Build Coastguard Worker self.assertTrue('mods.conv.weight' in m.state_dict().keys()) 9411*da0073e9SAndroid Build Coastguard Worker 9412*da0073e9SAndroid Build Coastguard Worker def test_script_sequential_multi_output_fail(self): 9413*da0073e9SAndroid Build Coastguard Worker class Sub(torch.jit.ScriptModule): 9414*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9415*da0073e9SAndroid Build Coastguard Worker super().__init__() 9416*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(2)) 9417*da0073e9SAndroid Build Coastguard Worker 9418*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9419*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 9420*da0073e9SAndroid Build Coastguard Worker return self.weight + thing 9421*da0073e9SAndroid Build Coastguard Worker 9422*da0073e9SAndroid Build Coastguard Worker class ReturnMulti(torch.jit.ScriptModule): 9423*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9424*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 9425*da0073e9SAndroid Build Coastguard Worker return x, x, x 9426*da0073e9SAndroid Build Coastguard Worker 9427*da0073e9SAndroid Build Coastguard Worker class HaveSequential(torch.jit.ScriptModule): 9428*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9429*da0073e9SAndroid Build Coastguard Worker super().__init__() 
9430*da0073e9SAndroid Build Coastguard Worker self.someseq = nn.Sequential( 9431*da0073e9SAndroid Build Coastguard Worker Sub(), 9432*da0073e9SAndroid Build Coastguard Worker ReturnMulti(), 9433*da0073e9SAndroid Build Coastguard Worker Sub() 9434*da0073e9SAndroid Build Coastguard Worker ) 9435*da0073e9SAndroid Build Coastguard Worker 9436*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9437*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 9438*da0073e9SAndroid Build Coastguard Worker return self.someseq(x) 9439*da0073e9SAndroid Build Coastguard Worker 9440*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"): 9441*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 9442*da0073e9SAndroid Build Coastguard Worker hs = HaveSequential() 9443*da0073e9SAndroid Build Coastguard Worker i = torch.empty(2) 9444*da0073e9SAndroid Build Coastguard Worker hs(i) 9445*da0073e9SAndroid Build Coastguard Worker 9446*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 9447*da0073e9SAndroid Build Coastguard Worker def test_script_sequential_in_mod_list(self): 9448*da0073e9SAndroid Build Coastguard Worker class Sub(torch.jit.ScriptModule): 9449*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9450*da0073e9SAndroid Build Coastguard Worker super().__init__() 9451*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(2)) 9452*da0073e9SAndroid Build Coastguard Worker 9453*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9454*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 9455*da0073e9SAndroid Build Coastguard Worker return self.weight + thing 9456*da0073e9SAndroid Build Coastguard Worker 9457*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 9458*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9459*da0073e9SAndroid 
Build Coastguard Worker super().__init__() 9460*da0073e9SAndroid Build Coastguard Worker self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())]) 9461*da0073e9SAndroid Build Coastguard Worker 9462*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9463*da0073e9SAndroid Build Coastguard Worker def forward(self, v): 9464*da0073e9SAndroid Build Coastguard Worker for mod in self.mods: 9465*da0073e9SAndroid Build Coastguard Worker v = mod(v) 9466*da0073e9SAndroid Build Coastguard Worker return v 9467*da0073e9SAndroid Build Coastguard Worker 9468*da0073e9SAndroid Build Coastguard Worker m = M() 9469*da0073e9SAndroid Build Coastguard Worker graph = str(m.graph) 9470*da0073e9SAndroid Build Coastguard Worker self.assertTrue(graph.count("prim::CallMethod") == 2) 9471*da0073e9SAndroid Build Coastguard Worker self.assertTrue("python" not in graph) 9472*da0073e9SAndroid Build Coastguard Worker 9473*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 9474*da0073e9SAndroid Build Coastguard Worker def test_script_nested_mod_list(self): 9475*da0073e9SAndroid Build Coastguard Worker class Sub(torch.jit.ScriptModule): 9476*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9477*da0073e9SAndroid Build Coastguard Worker super().__init__() 9478*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(2)) 9479*da0073e9SAndroid Build Coastguard Worker 9480*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9481*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 9482*da0073e9SAndroid Build Coastguard Worker return self.weight + thing 9483*da0073e9SAndroid Build Coastguard Worker 9484*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 9485*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9486*da0073e9SAndroid Build Coastguard Worker super().__init__() 9487*da0073e9SAndroid Build Coastguard Worker 
self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])]) 9488*da0073e9SAndroid Build Coastguard Worker 9489*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9490*da0073e9SAndroid Build Coastguard Worker def forward(self, v): 9491*da0073e9SAndroid Build Coastguard Worker for mod in self.mods: 9492*da0073e9SAndroid Build Coastguard Worker for m in mod: 9493*da0073e9SAndroid Build Coastguard Worker v = m(v) 9494*da0073e9SAndroid Build Coastguard Worker return v 9495*da0073e9SAndroid Build Coastguard Worker 9496*da0073e9SAndroid Build Coastguard Worker m = M() 9497*da0073e9SAndroid Build Coastguard Worker graph = str(m.graph) 9498*da0073e9SAndroid Build Coastguard Worker self.assertTrue(graph.count("prim::CallMethod") == 4) 9499*da0073e9SAndroid Build Coastguard Worker self.assertTrue("python" not in graph) 9500*da0073e9SAndroid Build Coastguard Worker 9501*da0073e9SAndroid Build Coastguard Worker def test_constant_as_attr(self): 9502*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 9503*da0073e9SAndroid Build Coastguard Worker __constants__ = ['dim'] 9504*da0073e9SAndroid Build Coastguard Worker 9505*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 9506*da0073e9SAndroid Build Coastguard Worker super().__init__() 9507*da0073e9SAndroid Build Coastguard Worker self.dim = 1 9508*da0073e9SAndroid Build Coastguard Worker 9509*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 9510*da0073e9SAndroid Build Coastguard Worker def forward(self, v): 9511*da0073e9SAndroid Build Coastguard Worker return torch.cat([v, v, v], dim=self.dim) 9512*da0073e9SAndroid Build Coastguard Worker v = torch.zeros(1, 1) 9513*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 9514*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.cat([v, v, v], dim=1), M()(v)) 9515*da0073e9SAndroid Build Coastguard Worker 9516*da0073e9SAndroid 
class StarTestSumStarred(torch.nn.Module):
    """Sum an arbitrary number of positional tensor inputs.

    Accumulation uses ``+=`` into ``inputs[0]``; for tensors that is an
    in-place add, so the first input is mutated when this runs eagerly.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *inputs):
        output = inputs[0]
        for i in range(1, len(inputs)):
            output += inputs[i]
        return output

class StarTestReturnThree(torch.nn.Module):
    """Return the single input three times as a tuple ``(rep, rep, rep)``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, rep):
        return rep, rep, rep

def test_script_star_expr(self):
    """A ``script_method`` can splat (``f(*tup)``) a tuple returned by a
    traced module into another traced module's varargs."""

    class M2(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
                                     (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
            self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))

        @torch.jit.script_method
        def forward(self, rep):
            tup = self.g(rep)
            return self.m(*tup)

    m = M2()
    self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))

def test_script_star_expr_string(self):
    """Same as ``test_script_star_expr`` but the method is compiled from a
    source string via ``ScriptModule.define``."""
    class M2(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
                                     (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
            self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))

            self.define('''
        def forward(self, rep):
            tup = self.g(rep)
            return self.m(*tup)
    ''')

    m = M2()
    self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))

class StarTestSumAndReturnThree(torch.nn.Module):
    """Sum all positional tensor inputs and return the sum three times.

    Like ``StarTestSumStarred``, accumulation is ``+=`` into ``inputs[0]``
    (in-place for tensors), so aliased inputs observe the running sum.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, *inputs):
        output = inputs[0]
        for i in range(1, len(inputs)):
            output += inputs[i]
        return output, output, output

def test_script_star_assign(self):
    """Starred assignment target ``head, *tail = ...`` on a traced module's
    tuple output, inside a ``define``-d method."""
    class M2(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
            self.define('''
        def forward(self, rep):
            head, *tail = self.g(rep)
            return head
    ''')

    m = M2()
    self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))

def test_script_module_star_assign2(self):
    """Leading starred target ``*head, tail = ...`` with an out-of-place trace.

    With ``_force_outplace=True`` the traced adds do not write back into the
    aliased input, so ``ones + ones + ones == 3 * ones``.
    """
    class M2(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.g = torch.jit.trace(
                TestScript.StarTestSumAndReturnThree(),
                (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
                _force_outplace=True)
            self.define('''
        def forward(self, rep):
            *head, tail = self.g(rep, rep, rep)
            return tail
    ''')

    m = M2()
    self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))

def test_script_module_star_assign2_inplace(self):
    """Same as ``test_script_module_star_assign2`` but traced in-place."""
    class M2(torch.jit.ScriptModule):
        def __init__(self) -> None:
            super().__init__()
            self.g = torch.jit.trace(
                TestScript.StarTestSumAndReturnThree(),
                (torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
                _force_outplace=False)
            self.define('''
        def forward(self, rep):
            *head, tail = self.g(rep, rep, rep)
            return tail
    ''')

    m = M2()
    # since forward() makes three aliases to the input `rep` before passing
    # it to StarTestSumAndReturnThree(), in-place behavior will be different
    # than the above out of place: rep += rep gives 2, then rep += rep gives 4.
    self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))

def test_script_module_star_assign_fail_pythonop(self):
    """Starred unpacking of an ignored Python function's result is rejected."""

    with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
        class M2(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()

                @torch.jit.ignore
                def myfunc():
                    return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)

                self.define('''
            def forward(self, rep):
                a, *b = myfunc()
                return a
            ''')

        m = M2()
        m(torch.zeros(4, 3))

def test_script_module_star_assign_fail_builtin(self):
    """Starred unpacking of a builtin's single Tensor result is rejected."""
    with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
        class M2(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()

                self.define('''
            def forward(self, rep):
                a, *b = torch.neg(rep)
                return a
            ''')

        m = M2()
        m(torch.zeros(4, 3))

def test_script_pack_padded_sequence(self):
    """Scripted pack_padded_sequence / pad_packed_sequence match eager mode."""
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    def pack_padded_pad_packed_script(x, seq_lens):
        x = pack_padded_sequence(x, seq_lens)
        x, lengths = pad_packed_sequence(x)
        return x, lengths

    T, B, C = 3, 5, 7
    x = torch.ones((T, B, C))
    seq_lens = torch.tensor([3, 3, 2, 2, 1])
    # set padding value so we can test equivalence
    for b in range(B):
        if seq_lens[b] < T:
            x[seq_lens[b]:, b, :] = 0

    eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
    with torch._jit_internal._disable_emit_hooks():
        scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
    script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
    self.assertEqual(eager_seq, script_seq)
    self.assertEqual(eager_lengths, script_lengths)

    class ExperimentalLSTM(torch.nn.Module):
        def __init__(self, input_dim, hidden_dim):
            super().__init__()

        def forward(self, input):
            # type: (Tensor)
            packed = pack_padded_sequence(
                input=input, lengths=torch.tensor([1, 2]), enforce_sorted=False
            )
            output, lengths = pad_packed_sequence(
                sequence=packed, total_length=2
            )
            # lengths is flipped, so is output
            return output[0]

    lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)

    with torch._jit_internal._disable_emit_hooks():
        self.checkModule(lstm, [torch.ones(2, 2)])

def test_script_pad_sequence_pack_sequence(self):
    """pad_sequence / pack_sequence are scriptable across their options."""
    from torch.nn.utils.rnn import pad_sequence, pack_sequence, pad_packed_sequence

    def pad_sequence_func(tensor_list, batch_first=False, padding_value=0.0, padding_side="right"):
        # type: (List[Tensor], bool, float, str) -> Tensor
        return pad_sequence(tensor_list, batch_first, padding_value, padding_side)

    def pack_sequence_func(tensor_list, enforce_sorted=True):
        # type: (List[Tensor], bool) -> Tensor
        return pad_packed_sequence(pack_sequence(tensor_list, enforce_sorted))[0]

    ones3 = torch.ones(3, 5)
    ones4 = torch.ones(4, 5)
    ones5 = torch.ones(5, 5)
    tensor1 = torch.tensor([1, 2, 3])
    tensor2 = torch.tensor([4, 5])
    tensor3 = torch.tensor([6])
    with torch._jit_internal._disable_emit_hooks():
        self.checkScript(pad_sequence_func,
                         ([ones3, ones4, ones5],))
        self.checkScript(pad_sequence_func,
                         ([ones3, ones4, ones5], True))
        self.checkScript(pad_sequence_func,
                         ([ones3, ones4, ones5], True, 2.5))
        self.checkScript(pad_sequence_func,
                         ([ones3, ones4, ones5], True, 2.5, "left"))
        self.checkScript(pad_sequence_func,
                         ([ones3, ones4, ones5], False, 2.5, "left"))
        self.checkScript(pack_sequence_func,
                         ([tensor1, tensor2, tensor3],))
        self.checkScript(pack_sequence_func,
                         ([tensor1, tensor2, tensor3], False))

def test_script_get_tracing_state(self):
    """``torch._C._get_tracing_state()`` is usable inside scripted code."""
    def test_if_tracing(x):
        if torch._C._get_tracing_state():
            return x + 1
        else:
            return x - 1

    inp = torch.randn(3, 3)
    self.checkScript(test_if_tracing, (inp,))

def test_script_is_tracing(self):
    """``torch.jit.is_tracing()`` is usable inside scripted code."""
    def test_is_tracing(x):
        if torch.jit.is_tracing():
            return x + 1
        else:
            return x - 1

    inp = torch.randn(3, 3)
    self.checkScript(test_is_tracing, (inp,))

def test_is_scripting(self):
    """``torch.jit.is_scripting()`` is False eagerly and True when scripted."""
    def foo():
        return torch.jit.is_scripting()

    self.assertFalse(foo())
    scripted = torch.jit.script(foo)
    self.assertTrue(scripted())

def test_comment_ignore_indent(self):
    """Scripting a module whose source has a badly indented comment works."""
    # NOTE: the comment inside ``__init__`` below is dedented on purpose;
    # that mis-indentation is exactly what this test exercises.
    class Model(torch.nn.Module):
        def __init__(self) -> None:
# useless comment that is not indented correctly  # noqa: E115
            super().__init__()

        def forward(self):
            return 5

    # should compile without an error
    self.checkModule(Model(), ())

def test_script_outputs(self):
    """Tuple-unpacking errors are reported at script-compile time."""
    with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
        @torch.jit.script
        def foo(a):
            c, d = a + a
            return c + d

    @torch.jit.script
    def return3():
        return 1, 2, 3

    with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
        @torch.jit.script
        def bind2():
            a, b = return3()
            print(a)
            print(b)

@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_script_get_device_cuda(self):
    """``Tensor.get_device()`` in script returns the CUDA device index."""
    @torch.jit.script
    def foo(a):
        return a.get_device()

    v = torch.randn(1, device='cuda')
    self.assertEqual(foo(v), 0)

def test_script_chunk(self):
    """Scripted ``torch.chunk`` with keyword args matches eager mode."""
    @torch.jit.script
    def foo(a):
        b, c = torch.chunk(a, dim=0, chunks=2)
        return b
    v = torch.rand(10, 3)
    self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))

def test_script_copy(self):
    """A scripted module with an Optional[Tensor] attribute can be copied."""
    class M(torch.nn.Module):
        __annotations__ = {
            "val": Optional[torch.Tensor]
        }

        def __init__(self) -> None:
            super().__init__()
            self.val = None

        def some_method(self):
            return 3

        def forward(self, x):
            # type: (Tensor) -> Tensor
            self.val = x + self.some_method()
            return x

    m = torch.jit.script(M())
    # test copy
    copy.copy(m)
    copy.deepcopy(m)

def test_script_forward_method_replacement(self):
    """An instance-level ``forward`` replacement is honored by scripting."""
    # We want to support the use case of attaching a different `forward` method
    class LowLevelModule(torch.nn.Module):
        def forward(self, input: torch.Tensor):
            # Generic forward dispatch
            return self.forward_pytorch(input) * 2

    class TestModule(LowLevelModule):
        def __init__(self) -> None:
            super().__init__()
            # Replace the forward method
            self.forward = types.MethodType(LowLevelModule.forward, self)

        def forward_pytorch(self, input: torch.Tensor):
            return torch.tensor(123)

        def forward(self, input: torch.Tensor):
            # Should not use this forward method
            raise AssertionError("This method should not be used")
            return self.forward_pytorch(input)

    m = TestModule()
    self.assertEqual(m(torch.tensor(1)), torch.tensor(246))

    m_scripted = torch.jit.script(m)
    self.assertEqual(m_scripted(torch.tensor(1)), torch.tensor(246))

def test_python_call_non_tensor(self):
    """Scripted code can call a Python function with non-Tensor args/returns."""
    def foo(a, b, c):
        # type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
        d, e = c
        return b + e, a + d

    @torch.jit.script
    def bar():
        x = torch.ones(3, 4)
        a, b = foo(x, 3, (x, 3))
        return a, b

    self.assertEqual((6, torch.ones(3, 4) + 1), bar())
9877*da0073e9SAndroid Build Coastguard Worker 9878*da0073e9SAndroid Build Coastguard Worker def test_python_call_non_tensor_wrong(self): 9879*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"): 9880*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 9881*da0073e9SAndroid Build Coastguard Worker def foo(): 9882*da0073e9SAndroid Build Coastguard Worker # type: () -> Tensor 9883*da0073e9SAndroid Build Coastguard Worker return ((3, 4),) # noqa: T484 9884*da0073e9SAndroid Build Coastguard Worker 9885*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 9886*da0073e9SAndroid Build Coastguard Worker def bar(): 9887*da0073e9SAndroid Build Coastguard Worker return foo() 9888*da0073e9SAndroid Build Coastguard Worker 9889*da0073e9SAndroid Build Coastguard Worker bar() 9890*da0073e9SAndroid Build Coastguard Worker 9891*da0073e9SAndroid Build Coastguard Worker def test_if_different_type(self): 9892*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "c0 is set to type " 9893*da0073e9SAndroid Build Coastguard Worker "int in the true branch and type " 9894*da0073e9SAndroid Build Coastguard Worker "float in the false branch"): 9895*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 9896*da0073e9SAndroid Build Coastguard Worker def diff_type_used(): 9897*da0073e9SAndroid Build Coastguard Worker if 1 == 2: 9898*da0073e9SAndroid Build Coastguard Worker c0 = 1 9899*da0073e9SAndroid Build Coastguard Worker else: 9900*da0073e9SAndroid Build Coastguard Worker c0 = 1.0 9901*da0073e9SAndroid Build Coastguard Worker return c0 9902*da0073e9SAndroid Build Coastguard Worker 9903*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Variable 'c0' previously had type float"): 9904*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 9905*da0073e9SAndroid Build Coastguard Worker def diff_existing_type(x): 9906*da0073e9SAndroid Build 
Coastguard Worker c0 = 1.0 9907*da0073e9SAndroid Build Coastguard Worker if 1 == 2: 9908*da0073e9SAndroid Build Coastguard Worker c0 = 1 9909*da0073e9SAndroid Build Coastguard Worker print(x) 9910*da0073e9SAndroid Build Coastguard Worker return x 9911*da0073e9SAndroid Build Coastguard Worker 9912*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 9913*da0073e9SAndroid Build Coastguard Worker def diff_type_unused(): 9914*da0073e9SAndroid Build Coastguard Worker if 1 == 1: 9915*da0073e9SAndroid Build Coastguard Worker c0 = 1 9916*da0073e9SAndroid Build Coastguard Worker print(c0) 9917*da0073e9SAndroid Build Coastguard Worker else: 9918*da0073e9SAndroid Build Coastguard Worker c0 = 1.0 9919*da0073e9SAndroid Build Coastguard Worker print(c0) 9920*da0073e9SAndroid Build Coastguard Worker return 1 9921*da0073e9SAndroid Build Coastguard Worker 9922*da0073e9SAndroid Build Coastguard Worker def test_if_not_defined_error(self): 9923*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the false branch"): 9924*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 9925*da0073e9SAndroid Build Coastguard Worker def test(): 9926*da0073e9SAndroid Build Coastguard Worker if 1 == 1: 9927*da0073e9SAndroid Build Coastguard Worker c0 = 1 9928*da0073e9SAndroid Build Coastguard Worker return c0 9929*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "c0 is not defined in the true branch"): 9930*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 9931*da0073e9SAndroid Build Coastguard Worker def test2(): 9932*da0073e9SAndroid Build Coastguard Worker if 1 == 1: 9933*da0073e9SAndroid Build Coastguard Worker pass 9934*da0073e9SAndroid Build Coastguard Worker else: 9935*da0073e9SAndroid Build Coastguard Worker c0 = 1 9936*da0073e9SAndroid Build Coastguard Worker return c0 9937*da0073e9SAndroid Build Coastguard Worker 9938*da0073e9SAndroid Build Coastguard Worker def test_if_list_cat(self): 
9939*da0073e9SAndroid Build Coastguard Worker # testing that different length lists don't throw error on cat in shape prop 9940*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 9941*da0073e9SAndroid Build Coastguard Worker def test_list(x): 9942*da0073e9SAndroid Build Coastguard Worker if bool(x.sum() < 1): 9943*da0073e9SAndroid Build Coastguard Worker c = [x, x] 9944*da0073e9SAndroid Build Coastguard Worker else: 9945*da0073e9SAndroid Build Coastguard Worker c = [x, x, x] 9946*da0073e9SAndroid Build Coastguard Worker return torch.cat(c) 9947*da0073e9SAndroid Build Coastguard Worker 9948*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(2, 4) 9949*da0073e9SAndroid Build Coastguard Worker _propagate_shapes(test_list.graph, (b,), False) 9950*da0073e9SAndroid Build Coastguard Worker 9951*da0073e9SAndroid Build Coastguard Worker def test_if_supertype(self): 9952*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 9953*da0073e9SAndroid Build Coastguard Worker def tensor_unifying(x, y, z): 9954*da0073e9SAndroid Build Coastguard Worker # testing dynamic is appropriately set for y and z 9955*da0073e9SAndroid Build Coastguard Worker if bool(x): 9956*da0073e9SAndroid Build Coastguard Worker x, y, z = x + 1, y, z 9957*da0073e9SAndroid Build Coastguard Worker else: 9958*da0073e9SAndroid Build Coastguard Worker x, y, z = x + 1, x, y 9959*da0073e9SAndroid Build Coastguard Worker 9960*da0073e9SAndroid Build Coastguard Worker return x, y, z 9961*da0073e9SAndroid Build Coastguard Worker 9962*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(2, 2, dtype=torch.float) 9963*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(2, 4, dtype=torch.long) 9964*da0073e9SAndroid Build Coastguard Worker c = torch.zeros(2, 4, dtype=torch.float) 9965*da0073e9SAndroid Build Coastguard Worker 9966*da0073e9SAndroid Build Coastguard Worker graph = _propagate_shapes(tensor_unifying.graph, (a, b, c), False) 9967*da0073e9SAndroid Build Coastguard Worker if_outputs = 
list(graph.findNode("prim::If").outputs()) 9968*da0073e9SAndroid Build Coastguard Worker self.assertTrue(if_outputs[0].type().str() == "Float(*, *, requires_grad=0, device=cpu)") 9969*da0073e9SAndroid Build Coastguard Worker self.assertTrue(if_outputs[1].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)") 9970*da0073e9SAndroid Build Coastguard Worker self.assertTrue(if_outputs[2].type().str() == "Tensor(*, *, requires_grad=0, device=cpu)") 9971*da0073e9SAndroid Build Coastguard Worker 9972*da0073e9SAndroid Build Coastguard Worker def test_list_unify(self): 9973*da0073e9SAndroid Build Coastguard Worker # allowing a unififed int?[] would cause a runtime error b/c 9974*da0073e9SAndroid Build Coastguard Worker # the index operation expects int?[] to be a generic list, 9975*da0073e9SAndroid Build Coastguard Worker # but in the true branch the IValue will be a int list 9976*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "int[] in the true branch and type None[]"): 9977*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 9978*da0073e9SAndroid Build Coastguard Worker def list_optional_fails(x): 9979*da0073e9SAndroid Build Coastguard Worker # type: (bool) -> Optional[int] 9980*da0073e9SAndroid Build Coastguard Worker if x: 9981*da0073e9SAndroid Build Coastguard Worker y = [1] 9982*da0073e9SAndroid Build Coastguard Worker else: 9983*da0073e9SAndroid Build Coastguard Worker y = [None] # noqa: T484 9984*da0073e9SAndroid Build Coastguard Worker return y[0] 9985*da0073e9SAndroid Build Coastguard Worker 9986*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 9987*da0073e9SAndroid Build Coastguard Worker def list_tensors(x): 9988*da0073e9SAndroid Build Coastguard Worker # type: (bool) -> Tuple[Tensor, List[Tensor]] 9989*da0073e9SAndroid Build Coastguard Worker if x: 9990*da0073e9SAndroid Build Coastguard Worker a = torch.zeros([1, 1]) 9991*da0073e9SAndroid Build Coastguard Worker y = [a] 9992*da0073e9SAndroid Build 
Coastguard Worker else: 9993*da0073e9SAndroid Build Coastguard Worker a = torch.zeros([1, 2]) 9994*da0073e9SAndroid Build Coastguard Worker y = [a] 9995*da0073e9SAndroid Build Coastguard Worker return a, y 9996*da0073e9SAndroid Build Coastguard Worker 9997*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_propagation', list_tensors.graph) 9998*da0073e9SAndroid Build Coastguard Worker m = self.createFunctionFromGraph(list_tensors.graph) 9999*da0073e9SAndroid Build Coastguard Worker # testing that tensor type of lists is unified 10000*da0073e9SAndroid Build Coastguard Worker self.getExportImportCopy(m) 10001*da0073e9SAndroid Build Coastguard Worker 10002*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a TorchDynamo suitable test") 10003*da0073e9SAndroid Build Coastguard Worker @_inline_everything 10004*da0073e9SAndroid Build Coastguard Worker def test_import_constants_not_specialized(self): 10005*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 10006*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 10007*da0073e9SAndroid Build Coastguard Worker return torch.cat(2 * [x], dim=0) 10008*da0073e9SAndroid Build Coastguard Worker 10009*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 10010*da0073e9SAndroid Build Coastguard Worker def __init__(self, mod): 10011*da0073e9SAndroid Build Coastguard Worker super().__init__() 10012*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(1, 3) 10013*da0073e9SAndroid Build Coastguard Worker mod_fn = lambda : mod(x) # noqa: E731 10014*da0073e9SAndroid Build Coastguard Worker self.mod = torch.jit.trace(mod_fn, ()) 10015*da0073e9SAndroid Build Coastguard Worker 10016*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10017*da0073e9SAndroid Build Coastguard Worker def forward(self): 10018*da0073e9SAndroid Build Coastguard Worker return self.mod() 10019*da0073e9SAndroid Build Coastguard Worker 10020*da0073e9SAndroid Build 
Coastguard Worker cm = ScriptMod(Mod()) 10021*da0073e9SAndroid Build Coastguard Worker # specialized tensor in graph 10022*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Float(1, 3, strides=[3, 1], requires_grad=0, device=cpu)").run(cm.forward.graph) 10023*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 10024*da0073e9SAndroid Build Coastguard Worker torch.jit.save(cm, buffer) 10025*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 10026*da0073e9SAndroid Build Coastguard Worker # when tensor is loaded as constant it isnt specialized 10027*da0073e9SAndroid Build Coastguard Worker cm_load = torch.jit.load(buffer) 10028*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("Float(1, 3)").run(cm_load.forward.graph) 10029*da0073e9SAndroid Build Coastguard Worker 10030*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 10031*da0073e9SAndroid Build Coastguard Worker def test_type_annotations_repeated_list(self): 10032*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10033*da0073e9SAndroid Build Coastguard Worker def float_fn(x, y): 10034*da0073e9SAndroid Build Coastguard Worker # type: (float, BroadcastingList3[float]) -> List[float] 10035*da0073e9SAndroid Build Coastguard Worker return y 10036*da0073e9SAndroid Build Coastguard Worker self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0])) 10037*da0073e9SAndroid Build Coastguard Worker self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0))) 10038*da0073e9SAndroid Build Coastguard Worker 10039*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10040*da0073e9SAndroid Build Coastguard Worker def float_fn_call(): 10041*da0073e9SAndroid Build Coastguard Worker print(float_fn(1.0, 1.0)) 10042*da0073e9SAndroid Build Coastguard Worker print(float_fn(1.0, (1.0, 1.0, 1.0))) 10043*da0073e9SAndroid Build Coastguard Worker 10044*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 
10045*da0073e9SAndroid Build Coastguard Worker def int_fn(x): 10046*da0073e9SAndroid Build Coastguard Worker # type: (BroadcastingList3[int]) -> List[int] 10047*da0073e9SAndroid Build Coastguard Worker return x 10048*da0073e9SAndroid Build Coastguard Worker self.assertEqual(int_fn(1), int_fn([1, 1, 1])) 10049*da0073e9SAndroid Build Coastguard Worker self.assertEqual(int_fn(1), int_fn((1, 1, 1))) 10050*da0073e9SAndroid Build Coastguard Worker 10051*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10052*da0073e9SAndroid Build Coastguard Worker def int_fn_call(): 10053*da0073e9SAndroid Build Coastguard Worker print(int_fn(1)) 10054*da0073e9SAndroid Build Coastguard Worker print(int_fn((1, 1, 1))) 10055*da0073e9SAndroid Build Coastguard Worker 10056*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"): 10057*da0073e9SAndroid Build Coastguard Worker @torch.jit.script # noqa: T484 10058*da0073e9SAndroid Build Coastguard Worker def fn(x): 10059*da0073e9SAndroid Build Coastguard Worker # type: (BroadcastingListx[int]) -> List[int] # noqa: T484 10060*da0073e9SAndroid Build Coastguard Worker return x 10061*da0073e9SAndroid Build Coastguard Worker 10062*da0073e9SAndroid Build Coastguard Worker # using CU so that flake8 error on int[2] is not raised (noqa not working) 10063*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"): 10064*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 10065*da0073e9SAndroid Build Coastguard Worker def nested(x, y): 10066*da0073e9SAndroid Build Coastguard Worker # type: (int, Tuple[int, int[2]]) -> List[int] 10067*da0073e9SAndroid Build Coastguard Worker return x # noqa: T484 10068*da0073e9SAndroid Build Coastguard Worker ''') 10069*da0073e9SAndroid Build Coastguard Worker 10070*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10071*da0073e9SAndroid Build Coastguard Worker def f(x: 
BroadcastingList2[int]): 10072*da0073e9SAndroid Build Coastguard Worker return x 10073*da0073e9SAndroid Build Coastguard Worker 10074*da0073e9SAndroid Build Coastguard Worker out = f(1) 10075*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(out[0], int)) 10076*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, [1, 1]) 10077*da0073e9SAndroid Build Coastguard Worker 10078*da0073e9SAndroid Build Coastguard Worker def test_ntuple_builtins(self): 10079*da0073e9SAndroid Build Coastguard Worker from torch.nn.modules.utils import _single, _pair, _triple, _quadruple 10080*da0073e9SAndroid Build Coastguard Worker 10081*da0073e9SAndroid Build Coastguard Worker def test_ints(): 10082*da0073e9SAndroid Build Coastguard Worker return _single(1), _pair(2), _triple(3), _quadruple(4) 10083*da0073e9SAndroid Build Coastguard Worker 10084*da0073e9SAndroid Build Coastguard Worker def test_floats(): 10085*da0073e9SAndroid Build Coastguard Worker return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1) 10086*da0073e9SAndroid Build Coastguard Worker 10087*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_ints, ()) 10088*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_floats, ()) 10089*da0073e9SAndroid Build Coastguard Worker 10090*da0073e9SAndroid Build Coastguard Worker def test_embedding_renorm_grad_error(self): 10091*da0073e9SAndroid Build Coastguard Worker # Testing that the builtin call to embedding_renorm_ correctly throws 10092*da0073e9SAndroid Build Coastguard Worker # Error when .backward() is called on its input 10093*da0073e9SAndroid Build Coastguard Worker 10094*da0073e9SAndroid Build Coastguard Worker def embedding_norm(input, embedding_matrix, max_norm): 10095*da0073e9SAndroid Build Coastguard Worker F.embedding(input, embedding_matrix, max_norm=0.01) 10096*da0073e9SAndroid Build Coastguard Worker 10097*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10098*da0073e9SAndroid Build Coastguard Worker def 
embedding_norm_script(input, embedding_matrix, max_norm): 10099*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor, float) -> None 10100*da0073e9SAndroid Build Coastguard Worker F.embedding(input, embedding_matrix, max_norm=0.01) 10101*da0073e9SAndroid Build Coastguard Worker 10102*da0073e9SAndroid Build Coastguard Worker for _ in [embedding_norm, embedding_norm_script]: 10103*da0073e9SAndroid Build Coastguard Worker input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]]) 10104*da0073e9SAndroid Build Coastguard Worker embedding_matrix = torch.randn(10, 3) 10105*da0073e9SAndroid Build Coastguard Worker 10106*da0073e9SAndroid Build Coastguard Worker var1 = torch.randn(10, 3, requires_grad=True) 10107*da0073e9SAndroid Build Coastguard Worker var2 = var1.detach().requires_grad_() 10108*da0073e9SAndroid Build Coastguard Worker output1 = var1 * embedding_matrix 10109*da0073e9SAndroid Build Coastguard Worker output2 = var2 * embedding_matrix 10110*da0073e9SAndroid Build Coastguard Worker 10111*da0073e9SAndroid Build Coastguard Worker output1.sum().backward() 10112*da0073e9SAndroid Build Coastguard Worker 10113*da0073e9SAndroid Build Coastguard Worker ignore = F.embedding(input, embedding_matrix, max_norm=0.01) 10114*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "modified"): 10115*da0073e9SAndroid Build Coastguard Worker output2.sum().backward() 10116*da0073e9SAndroid Build Coastguard Worker 10117*da0073e9SAndroid Build Coastguard Worker def test_type_annotations(self): 10118*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 10119*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor] 10120*da0073e9SAndroid Build Coastguard Worker return x, x * 2, x * 3 10121*da0073e9SAndroid Build Coastguard Worker 10122*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"): 10123*da0073e9SAndroid Build Coastguard Worker 
@torch.jit.script 10124*da0073e9SAndroid Build Coastguard Worker def script_fn(x): 10125*da0073e9SAndroid Build Coastguard Worker x, y, z, w = fn(x, x) 10126*da0073e9SAndroid Build Coastguard Worker 10127*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"): 10128*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10129*da0073e9SAndroid Build Coastguard Worker def script_fn2(x): 10130*da0073e9SAndroid Build Coastguard Worker x, y = fn(x, x) 10131*da0073e9SAndroid Build Coastguard Worker 10132*da0073e9SAndroid Build Coastguard Worker def fn_unpack(x): 10133*da0073e9SAndroid Build Coastguard Worker y, z, w = fn(x, x) 10134*da0073e9SAndroid Build Coastguard Worker return y 10135*da0073e9SAndroid Build Coastguard Worker 10136*da0073e9SAndroid Build Coastguard Worker def fn_index(x): 10137*da0073e9SAndroid Build Coastguard Worker q = fn(x, x) 10138*da0073e9SAndroid Build Coastguard Worker return x 10139*da0073e9SAndroid Build Coastguard Worker 10140*da0073e9SAndroid Build Coastguard Worker def fn_string(str, strpair): 10141*da0073e9SAndroid Build Coastguard Worker # type: (str, Tuple[str, str]) -> Tuple[str, int, str, str] 10142*da0073e9SAndroid Build Coastguard Worker str1, str2 = strpair 10143*da0073e9SAndroid Build Coastguard Worker return str, 2, str1, str2 10144*da0073e9SAndroid Build Coastguard Worker 10145*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 10146*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_unpack, (x,), optimize=True) 10147*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_index, (x,), optimize=True) 10148*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_string, ("1", ("3", "4")), optimize=True) 10149*da0073e9SAndroid Build Coastguard Worker 10150*da0073e9SAndroid Build Coastguard Worker def test_type_annotations_varargs(self): 10151*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 10152*da0073e9SAndroid 
Build Coastguard Worker def fn_varargs(x, *args): 10153*da0073e9SAndroid Build Coastguard Worker return args[0] if args else x 10154*da0073e9SAndroid Build Coastguard Worker 10155*da0073e9SAndroid Build Coastguard Worker def fn1(x, y, z): 10156*da0073e9SAndroid Build Coastguard Worker return fn_varargs(x) 10157*da0073e9SAndroid Build Coastguard Worker 10158*da0073e9SAndroid Build Coastguard Worker def fn2(x, y, z): 10159*da0073e9SAndroid Build Coastguard Worker return fn_varargs(x, y) 10160*da0073e9SAndroid Build Coastguard Worker 10161*da0073e9SAndroid Build Coastguard Worker def fn3(x, y, z): 10162*da0073e9SAndroid Build Coastguard Worker return fn_varargs(x, y, z) 10163*da0073e9SAndroid Build Coastguard Worker 10164*da0073e9SAndroid Build Coastguard Worker x, y, z = (torch.randn(2, 2) for _ in range(3)) 10165*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn1, (x, y, z), optimize=True) 10166*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn2, (x, y, z), optimize=True) 10167*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn3, (x, y, z), optimize=True) 10168*da0073e9SAndroid Build Coastguard Worker 10169*da0073e9SAndroid Build Coastguard Worker def test_type_annotation_py3(self): 10170*da0073e9SAndroid Build Coastguard Worker code = dedent(""" 10171*da0073e9SAndroid Build Coastguard Worker import torch 10172*da0073e9SAndroid Build Coastguard Worker from torch import Tensor 10173*da0073e9SAndroid Build Coastguard Worker from typing import Tuple 10174*da0073e9SAndroid Build Coastguard Worker 10175*da0073e9SAndroid Build Coastguard Worker def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]: 10176*da0073e9SAndroid Build Coastguard Worker return (x, y + z, z) 10177*da0073e9SAndroid Build Coastguard Worker """) 10178*da0073e9SAndroid Build Coastguard Worker 10179*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryDirectory() as tmp_dir: 10180*da0073e9SAndroid Build Coastguard Worker script_path = 
os.path.join(tmp_dir, 'script.py') 10181*da0073e9SAndroid Build Coastguard Worker with open(script_path, 'w') as f: 10182*da0073e9SAndroid Build Coastguard Worker f.write(code) 10183*da0073e9SAndroid Build Coastguard Worker fn = get_fn('test_type_annotation_py3', script_path) 10184*da0073e9SAndroid Build Coastguard Worker fn = torch.jit.ignore(fn) 10185*da0073e9SAndroid Build Coastguard Worker 10186*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Expected a value of type 'Tensor' for argument" 10187*da0073e9SAndroid Build Coastguard Worker r" 'x' but instead found type 'Tuple\[Tensor,"): 10188*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10189*da0073e9SAndroid Build Coastguard Worker def bad_fn(x): 10190*da0073e9SAndroid Build Coastguard Worker x, y = fn((x, x), x, x) 10191*da0073e9SAndroid Build Coastguard Worker return y 10192*da0073e9SAndroid Build Coastguard Worker 10193*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"): 10194*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10195*da0073e9SAndroid Build Coastguard Worker def bad_fn2(x): 10196*da0073e9SAndroid Build Coastguard Worker x, y = fn(x, x, x) 10197*da0073e9SAndroid Build Coastguard Worker return y 10198*da0073e9SAndroid Build Coastguard Worker 10199*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"): 10200*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10201*da0073e9SAndroid Build Coastguard Worker def bad_fn3(x): 10202*da0073e9SAndroid Build Coastguard Worker x, y, z, w = fn(x, x, x) 10203*da0073e9SAndroid Build Coastguard Worker return y 10204*da0073e9SAndroid Build Coastguard Worker 10205*da0073e9SAndroid Build Coastguard Worker def good_fn(x): 10206*da0073e9SAndroid Build Coastguard Worker y, z, w = fn(x, x, x) 10207*da0073e9SAndroid Build Coastguard Worker return y, z, w 
10208*da0073e9SAndroid Build Coastguard Worker 10209*da0073e9SAndroid Build Coastguard Worker self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True) 10210*da0073e9SAndroid Build Coastguard Worker 10211*da0073e9SAndroid Build Coastguard Worker def test_type_annotation_module(self): 10212*da0073e9SAndroid Build Coastguard Worker class BaseModule(torch.jit.ScriptModule): 10213*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 10214*da0073e9SAndroid Build Coastguard Worker def foo(self, x): 10215*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 10216*da0073e9SAndroid Build Coastguard Worker return x + 1 10217*da0073e9SAndroid Build Coastguard Worker 10218*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 10219*da0073e9SAndroid Build Coastguard Worker def bar(self, x, y): 10220*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> Tuple[Tensor, Tensor] 10221*da0073e9SAndroid Build Coastguard Worker return x + y, y 10222*da0073e9SAndroid Build Coastguard Worker 10223*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 10224*da0073e9SAndroid Build Coastguard Worker def baz(self, x, y): 10225*da0073e9SAndroid Build Coastguard Worker return x 10226*da0073e9SAndroid Build Coastguard Worker 10227*da0073e9SAndroid Build Coastguard Worker class ModuleTooMany(BaseModule): 10228*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10229*da0073e9SAndroid Build Coastguard Worker def method(self, x): 10230*da0073e9SAndroid Build Coastguard Worker return self.foo(x, x) 10231*da0073e9SAndroid Build Coastguard Worker 10232*da0073e9SAndroid Build Coastguard Worker class ModuleTooFew(BaseModule): 10233*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10234*da0073e9SAndroid Build Coastguard Worker def method(self, x): 10235*da0073e9SAndroid Build Coastguard Worker return self.bar(x) 10236*da0073e9SAndroid Build Coastguard Worker 10237*da0073e9SAndroid Build Coastguard Worker class 
ModuleTooManyAssign(BaseModule): 10238*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10239*da0073e9SAndroid Build Coastguard Worker def method(self, x): 10240*da0073e9SAndroid Build Coastguard Worker y, z, w = self.bar(x, x) 10241*da0073e9SAndroid Build Coastguard Worker return x 10242*da0073e9SAndroid Build Coastguard Worker 10243*da0073e9SAndroid Build Coastguard Worker class ModuleDefault(BaseModule): 10244*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10245*da0073e9SAndroid Build Coastguard Worker def method(self, x): 10246*da0073e9SAndroid Build Coastguard Worker y = self.baz(x) 10247*da0073e9SAndroid Build Coastguard Worker return x 10248*da0073e9SAndroid Build Coastguard Worker 10249*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected at most 2 arguments but found 3"): 10250*da0073e9SAndroid Build Coastguard Worker ModuleTooMany() 10251*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Argument y not provided"): 10252*da0073e9SAndroid Build Coastguard Worker ModuleTooFew() 10253*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"): 10254*da0073e9SAndroid Build Coastguard Worker ModuleTooManyAssign() 10255*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Argument y not provided."): 10256*da0073e9SAndroid Build Coastguard Worker ModuleDefault() 10257*da0073e9SAndroid Build Coastguard Worker 10258*da0073e9SAndroid Build Coastguard Worker def test_type_inferred_from_empty_annotation(self): 10259*da0073e9SAndroid Build Coastguard Worker """ 10260*da0073e9SAndroid Build Coastguard Worker Test that the type inferred from an empty or missing annotation is Torch.Tensor wtih `inferred=true` 10261*da0073e9SAndroid Build Coastguard Worker """ 10262*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10263*da0073e9SAndroid Build Coastguard Worker def 
fn(x): 10264*da0073e9SAndroid Build Coastguard Worker return x 10265*da0073e9SAndroid Build Coastguard Worker 10266*da0073e9SAndroid Build Coastguard Worker graph = fn.graph 10267*da0073e9SAndroid Build Coastguard Worker n = next(graph.inputs()) 10268*da0073e9SAndroid Build Coastguard Worker self.assertTrue(n.type() == torch._C.TensorType.getInferred()) 10269*da0073e9SAndroid Build Coastguard Worker 10270*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Inferred 'x' to be of type 'Tensor"): 10271*da0073e9SAndroid Build Coastguard Worker fn("1") 10272*da0073e9SAndroid Build Coastguard Worker 10273*da0073e9SAndroid Build Coastguard Worker def test_script_define_order(self): 10274*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 10275*da0073e9SAndroid Build Coastguard Worker 10276*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10277*da0073e9SAndroid Build Coastguard Worker def call_foo(self, input): 10278*da0073e9SAndroid Build Coastguard Worker return self.foo(input) 10279*da0073e9SAndroid Build Coastguard Worker 10280*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10281*da0073e9SAndroid Build Coastguard Worker def foo(self, input): 10282*da0073e9SAndroid Build Coastguard Worker return input + 1 10283*da0073e9SAndroid Build Coastguard Worker m = M() 10284*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64))) 10285*da0073e9SAndroid Build Coastguard Worker 10286*da0073e9SAndroid Build Coastguard Worker def test_script_define_order_recursive_fail(self): 10287*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 10288*da0073e9SAndroid Build Coastguard Worker 10289*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10290*da0073e9SAndroid Build Coastguard Worker def call_foo(self, input): 10291*da0073e9SAndroid Build Coastguard Worker return self.foo(input) 10292*da0073e9SAndroid Build 
Coastguard Worker 10293*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10294*da0073e9SAndroid Build Coastguard Worker def foo(self, input): 10295*da0073e9SAndroid Build Coastguard Worker self.call_foo(input) 10296*da0073e9SAndroid Build Coastguard Worker 10297*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'called recursively'): 10298*da0073e9SAndroid Build Coastguard Worker M() 10299*da0073e9SAndroid Build Coastguard Worker 10300*da0073e9SAndroid Build Coastguard Worker def test_script_kwargs_fn_call(self): 10301*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 10302*da0073e9SAndroid Build Coastguard Worker 10303*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10304*da0073e9SAndroid Build Coastguard Worker def call_foo(self, input): 10305*da0073e9SAndroid Build Coastguard Worker return self.foo(input=input, bar=1) 10306*da0073e9SAndroid Build Coastguard Worker 10307*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10308*da0073e9SAndroid Build Coastguard Worker def foo(self, bar, input): 10309*da0073e9SAndroid Build Coastguard Worker # type: (int, Tensor) -> Tensor 10310*da0073e9SAndroid Build Coastguard Worker return input + bar 10311*da0073e9SAndroid Build Coastguard Worker m = M() 10312*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64))) 10313*da0073e9SAndroid Build Coastguard Worker 10314*da0073e9SAndroid Build Coastguard Worker def test_if_define(self): 10315*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10316*da0073e9SAndroid Build Coastguard Worker def foo(a): 10317*da0073e9SAndroid Build Coastguard Worker if bool(a == 0): 10318*da0073e9SAndroid Build Coastguard Worker b = 1 10319*da0073e9SAndroid Build Coastguard Worker else: 10320*da0073e9SAndroid Build Coastguard Worker b = 0 10321*da0073e9SAndroid Build Coastguard Worker return b + 1 10322*da0073e9SAndroid Build Coastguard 
Worker 10323*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10324*da0073e9SAndroid Build Coastguard Worker def foo2(a): 10325*da0073e9SAndroid Build Coastguard Worker b = 0 10326*da0073e9SAndroid Build Coastguard Worker if bool(a == 0): 10327*da0073e9SAndroid Build Coastguard Worker b = 1 10328*da0073e9SAndroid Build Coastguard Worker return b + 1 10329*da0073e9SAndroid Build Coastguard Worker 10330*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10331*da0073e9SAndroid Build Coastguard Worker def foo3(a): 10332*da0073e9SAndroid Build Coastguard Worker b = 1 10333*da0073e9SAndroid Build Coastguard Worker if bool(a == 0): 10334*da0073e9SAndroid Build Coastguard Worker c = 4 10335*da0073e9SAndroid Build Coastguard Worker else: 10336*da0073e9SAndroid Build Coastguard Worker b = 0 10337*da0073e9SAndroid Build Coastguard Worker return b + 1 10338*da0073e9SAndroid Build Coastguard Worker 10339*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, dtype=torch.long) 10340*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(1, dtype=torch.long) 10341*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, foo(a)) 10342*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, foo(b)) 10343*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, foo2(a)) 10344*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, foo2(b)) 10345*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, foo3(a)) 10346*da0073e9SAndroid Build Coastguard Worker self.assertEqual(2, foo3(b)) 10347*da0073e9SAndroid Build Coastguard Worker 10348*da0073e9SAndroid Build Coastguard Worker def test_script_module_export_submodule(self): 10349*da0073e9SAndroid Build Coastguard Worker class M1(torch.jit.ScriptModule): 10350*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 10351*da0073e9SAndroid Build Coastguard Worker super().__init__() 10352*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(2)) 
10353*da0073e9SAndroid Build Coastguard Worker 10354*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10355*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 10356*da0073e9SAndroid Build Coastguard Worker return self.weight + thing 10357*da0073e9SAndroid Build Coastguard Worker 10358*da0073e9SAndroid Build Coastguard Worker class M2(torch.jit.ScriptModule): 10359*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 10360*da0073e9SAndroid Build Coastguard Worker super().__init__() 10361*da0073e9SAndroid Build Coastguard Worker # test submodule 10362*da0073e9SAndroid Build Coastguard Worker self.sub = M1() 10363*da0073e9SAndroid Build Coastguard Worker self.weight = nn.Parameter(torch.randn(2, 3)) 10364*da0073e9SAndroid Build Coastguard Worker self.bias = nn.Parameter(torch.randn(2)) 10365*da0073e9SAndroid Build Coastguard Worker self.define(""" 10366*da0073e9SAndroid Build Coastguard Worker def hi(self, a): 10367*da0073e9SAndroid Build Coastguard Worker return self.weight.mm(a) 10368*da0073e9SAndroid Build Coastguard Worker """) 10369*da0073e9SAndroid Build Coastguard Worker 10370*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10371*da0073e9SAndroid Build Coastguard Worker def doit(self, input): 10372*da0073e9SAndroid Build Coastguard Worker return self.weight.mm(input) 10373*da0073e9SAndroid Build Coastguard Worker 10374*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10375*da0073e9SAndroid Build Coastguard Worker def doit2(self, input): 10376*da0073e9SAndroid Build Coastguard Worker return self.weight.mm(input) 10377*da0073e9SAndroid Build Coastguard Worker 10378*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10379*da0073e9SAndroid Build Coastguard Worker def doit3(self, input): 10380*da0073e9SAndroid Build Coastguard Worker return input + torch.ones([1], dtype=torch.double) 10381*da0073e9SAndroid Build Coastguard Worker 10382*da0073e9SAndroid Build Coastguard 
Worker @torch.jit.script_method 10383*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 10384*da0073e9SAndroid Build Coastguard Worker a = self.doit(input) 10385*da0073e9SAndroid Build Coastguard Worker b = self.doit2(input) 10386*da0073e9SAndroid Build Coastguard Worker c = self.hi(input) 10387*da0073e9SAndroid Build Coastguard Worker return a + b + self.bias + c 10388*da0073e9SAndroid Build Coastguard Worker 10389*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 10390*da0073e9SAndroid Build Coastguard Worker m_orig = M2() 10391*da0073e9SAndroid Build Coastguard Worker m_import = self.getExportImportCopy(m_orig) 10392*da0073e9SAndroid Build Coastguard Worker 10393*da0073e9SAndroid Build Coastguard Worker input = torch.randn(3, 2) 10394*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_orig.doit(input), m_import.doit(input)) 10395*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_orig.hi(input), m_import.hi(input)) 10396*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_orig.doit3(input), m_import.doit3(input)) 10397*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_orig.forward(input), m_import.forward(input)) 10398*da0073e9SAndroid Build Coastguard Worker 10399*da0073e9SAndroid Build Coastguard Worker @slowTest 10400*da0073e9SAndroid Build Coastguard Worker def test_compile_module_with_constant(self): 10401*da0073e9SAndroid Build Coastguard Worker class Double(nn.Module): 10402*da0073e9SAndroid Build Coastguard Worker def __init__(self, downsample=None): 10403*da0073e9SAndroid Build Coastguard Worker super().__init__() 10404*da0073e9SAndroid Build Coastguard Worker 10405*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 10406*da0073e9SAndroid Build Coastguard Worker return input * 2 10407*da0073e9SAndroid Build Coastguard Worker 10408*da0073e9SAndroid Build Coastguard Worker class Mod(nn.Module): 10409*da0073e9SAndroid Build Coastguard Worker __constants__ 
= ['downsample'] 10410*da0073e9SAndroid Build Coastguard Worker 10411*da0073e9SAndroid Build Coastguard Worker def __init__(self, downsample=None): 10412*da0073e9SAndroid Build Coastguard Worker super().__init__() 10413*da0073e9SAndroid Build Coastguard Worker self.downsample = downsample 10414*da0073e9SAndroid Build Coastguard Worker 10415*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 10416*da0073e9SAndroid Build Coastguard Worker if self.downsample is not None: 10417*da0073e9SAndroid Build Coastguard Worker return self.downsample(input) 10418*da0073e9SAndroid Build Coastguard Worker return input 10419*da0073e9SAndroid Build Coastguard Worker 10420*da0073e9SAndroid Build Coastguard Worker none_mod = torch.jit.script(Mod(None)) 10421*da0073e9SAndroid Build Coastguard Worker double_mod = torch.jit.script(Mod(Double())) 10422*da0073e9SAndroid Build Coastguard Worker self.assertEqual(none_mod(torch.tensor(1)), torch.tensor(1)) 10423*da0073e9SAndroid Build Coastguard Worker self.assertEqual(double_mod(torch.tensor(1)), torch.tensor(1) * 2) 10424*da0073e9SAndroid Build Coastguard Worker 10425*da0073e9SAndroid Build Coastguard Worker def test_device_kwarg(self): 10426*da0073e9SAndroid Build Coastguard Worker from torch import device 10427*da0073e9SAndroid Build Coastguard Worker 10428*da0073e9SAndroid Build Coastguard Worker def f(): 10429*da0073e9SAndroid Build Coastguard Worker return device(type='cuda'), torch.device(type='cpu') 10430*da0073e9SAndroid Build Coastguard Worker self.checkScript(f, ()) 10431*da0073e9SAndroid Build Coastguard Worker 10432*da0073e9SAndroid Build Coastguard Worker def test_script_module_export_tensor_type(self): 10433*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 10434*da0073e9SAndroid Build Coastguard Worker def __init__(self, type): 10435*da0073e9SAndroid Build Coastguard Worker super().__init__() 10436*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.zeros((5, 
5), dtype=type).random_()) 10437*da0073e9SAndroid Build Coastguard Worker 10438*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10439*da0073e9SAndroid Build Coastguard Worker def foo(self): 10440*da0073e9SAndroid Build Coastguard Worker return self.param 10441*da0073e9SAndroid Build Coastguard Worker 10442*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 10443*da0073e9SAndroid Build Coastguard Worker for type in [torch.float, torch.double]: 10444*da0073e9SAndroid Build Coastguard Worker m_orig = M(type) 10445*da0073e9SAndroid Build Coastguard Worker m_import = self.getExportImportCopy(m_orig) 10446*da0073e9SAndroid Build Coastguard Worker # check to make sure the storage wasn't resized 10447*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m_orig.param.storage().size() == 25) 10448*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_orig.foo(), m_import.foo()) 10449*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype) 10450*da0073e9SAndroid Build Coastguard Worker 10451*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA") 10452*da0073e9SAndroid Build Coastguard Worker def test_script_module_export_tensor_cuda(self): 10453*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 10454*da0073e9SAndroid Build Coastguard Worker 10455*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 10456*da0073e9SAndroid Build Coastguard Worker super().__init__() 10457*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_()) 10458*da0073e9SAndroid Build Coastguard Worker 10459*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10460*da0073e9SAndroid Build Coastguard Worker def foo(self): 10461*da0073e9SAndroid Build Coastguard Worker return self.param 10462*da0073e9SAndroid Build Coastguard Worker 
10463*da0073e9SAndroid Build Coastguard Worker m_orig = M() 10464*da0073e9SAndroid Build Coastguard Worker m_import = self.getExportImportCopy(m_orig) 10465*da0073e9SAndroid Build Coastguard Worker # check to make sure the storage wasn't resized 10466*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m_orig.param.storage().size() == 25) 10467*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m_import.foo().device == torch.device('cuda:0')) 10468*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_orig.foo(), m_import.foo()) 10469*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype) 10470*da0073e9SAndroid Build Coastguard Worker 10471*da0073e9SAndroid Build Coastguard Worker def test_script_module_export_blocks(self): 10472*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 10473*da0073e9SAndroid Build Coastguard Worker def __init__(self, n, m): 10474*da0073e9SAndroid Build Coastguard Worker super().__init__() 10475*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.rand(n, m)) 10476*da0073e9SAndroid Build Coastguard Worker 10477*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10478*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 10479*da0073e9SAndroid Build Coastguard Worker if bool(input.sum() > 0): 10480*da0073e9SAndroid Build Coastguard Worker output = self.weight.mv(input) 10481*da0073e9SAndroid Build Coastguard Worker else: 10482*da0073e9SAndroid Build Coastguard Worker output = self.weight + input 10483*da0073e9SAndroid Build Coastguard Worker return output 10484*da0073e9SAndroid Build Coastguard Worker 10485*da0073e9SAndroid Build Coastguard Worker m_orig = M(200, 200) 10486*da0073e9SAndroid Build Coastguard Worker m_import = self.getExportImportCopy(m_orig) 10487*da0073e9SAndroid Build Coastguard Worker 10488*da0073e9SAndroid Build Coastguard Worker t = torch.rand(200) 10489*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(m_orig(t), m_import(t)) 10490*da0073e9SAndroid Build Coastguard Worker 10491*da0073e9SAndroid Build Coastguard Worker def test_script_module_export_shared_storage(self): 10492*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 10493*da0073e9SAndroid Build Coastguard Worker 10494*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 10495*da0073e9SAndroid Build Coastguard Worker super().__init__() 10496*da0073e9SAndroid Build Coastguard Worker self.param1 = torch.nn.Parameter(torch.rand(5, 5)) 10497*da0073e9SAndroid Build Coastguard Worker self.param2 = torch.nn.Parameter(self.param1[3]) 10498*da0073e9SAndroid Build Coastguard Worker self.param3 = torch.nn.Parameter(torch.rand(5, 5)) 10499*da0073e9SAndroid Build Coastguard Worker self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6]) 10500*da0073e9SAndroid Build Coastguard Worker 10501*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10502*da0073e9SAndroid Build Coastguard Worker def foo(self): 10503*da0073e9SAndroid Build Coastguard Worker return self.param1 + self.param2 + self.param3 + self.param4 10504*da0073e9SAndroid Build Coastguard Worker 10505*da0073e9SAndroid Build Coastguard Worker with torch.jit.optimized_execution(False): 10506*da0073e9SAndroid Build Coastguard Worker m_orig = M() 10507*da0073e9SAndroid Build Coastguard Worker m_import = self.getExportImportCopy(m_orig) 10508*da0073e9SAndroid Build Coastguard Worker 10509*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m_orig.foo(), m_import.foo()) 10510*da0073e9SAndroid Build Coastguard Worker 10511*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr()) 10512*da0073e9SAndroid Build Coastguard Worker self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr()) 10513*da0073e9SAndroid Build Coastguard Worker 10514*da0073e9SAndroid Build 
Coastguard Worker def test_sequential_intermediary_types(self): 10515*da0073e9SAndroid Build Coastguard Worker class A(torch.nn.Module): 10516*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 10517*da0073e9SAndroid Build Coastguard Worker return x + 3 10518*da0073e9SAndroid Build Coastguard Worker 10519*da0073e9SAndroid Build Coastguard Worker class B(torch.nn.Module): 10520*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 10521*da0073e9SAndroid Build Coastguard Worker return {"1": x} 10522*da0073e9SAndroid Build Coastguard Worker 10523*da0073e9SAndroid Build Coastguard Worker class C(torch.nn.Module): 10524*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 10525*da0073e9SAndroid Build Coastguard Worker super().__init__() 10526*da0073e9SAndroid Build Coastguard Worker self.foo = torch.nn.Sequential(A(), B()) 10527*da0073e9SAndroid Build Coastguard Worker 10528*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 10529*da0073e9SAndroid Build Coastguard Worker return self.foo(x) 10530*da0073e9SAndroid Build Coastguard Worker 10531*da0073e9SAndroid Build Coastguard Worker self.checkModule(C(), (torch.tensor(1),)) 10532*da0073e9SAndroid Build Coastguard Worker 10533*da0073e9SAndroid Build Coastguard Worker def test_ellipsis_const_mid(self): 10534*da0073e9SAndroid Build Coastguard Worker def ellipsize(x): 10535*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> List[int] 10536*da0073e9SAndroid Build Coastguard Worker return x[2, Ellipsis, 0:4, 4:8].size() 10537*da0073e9SAndroid Build Coastguard Worker 10538*da0073e9SAndroid Build Coastguard Worker dummy = torch.zeros(8, 8, 8, 8, 8) 10539*da0073e9SAndroid Build Coastguard Worker self.checkScript(ellipsize, (dummy,), optimize=True) 10540*da0073e9SAndroid Build Coastguard Worker 10541*da0073e9SAndroid Build Coastguard Worker def test_ellipsis_const_mid_select(self): 10542*da0073e9SAndroid Build Coastguard Worker def ellipsize(x): 10543*da0073e9SAndroid 
Build Coastguard Worker # type: (Tensor) -> List[int] 10544*da0073e9SAndroid Build Coastguard Worker return x[2, Ellipsis, 4, 4, 4:8, 2].size() 10545*da0073e9SAndroid Build Coastguard Worker 10546*da0073e9SAndroid Build Coastguard Worker dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8) 10547*da0073e9SAndroid Build Coastguard Worker self.checkScript(ellipsize, (dummy,), optimize=True) 10548*da0073e9SAndroid Build Coastguard Worker 10549*da0073e9SAndroid Build Coastguard Worker def test_ellipsis_const_start(self): 10550*da0073e9SAndroid Build Coastguard Worker def ellipsize(x): 10551*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> List[int] 10552*da0073e9SAndroid Build Coastguard Worker return x[Ellipsis, 0:4, 4:8].size() 10553*da0073e9SAndroid Build Coastguard Worker dummy = torch.zeros(8, 8, 8, 8, 8) 10554*da0073e9SAndroid Build Coastguard Worker self.checkScript(ellipsize, (dummy,), optimize=True) 10555*da0073e9SAndroid Build Coastguard Worker 10556*da0073e9SAndroid Build Coastguard Worker def test_ellipsis_const_end(self): 10557*da0073e9SAndroid Build Coastguard Worker def ellipsize(x): 10558*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> List[int] 10559*da0073e9SAndroid Build Coastguard Worker return x[0:4, 2, Ellipsis].size() 10560*da0073e9SAndroid Build Coastguard Worker dummy = torch.zeros(8, 8, 8, 8, 8) 10561*da0073e9SAndroid Build Coastguard Worker self.checkScript(ellipsize, (dummy,), optimize=True) 10562*da0073e9SAndroid Build Coastguard Worker 10563*da0073e9SAndroid Build Coastguard Worker def test_ellipsis_mid(self): 10564*da0073e9SAndroid Build Coastguard Worker def ellipsize(x): 10565*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> List[int] 10566*da0073e9SAndroid Build Coastguard Worker return x[2, ..., 0:4, 4:8].size() 10567*da0073e9SAndroid Build Coastguard Worker 10568*da0073e9SAndroid Build Coastguard Worker dummy = torch.zeros(8, 8, 8, 8, 8) 10569*da0073e9SAndroid Build Coastguard Worker 
self.checkScript(ellipsize, (dummy,), optimize=True) 10570*da0073e9SAndroid Build Coastguard Worker 10571*da0073e9SAndroid Build Coastguard Worker def test_ellipsis_mid_select(self): 10572*da0073e9SAndroid Build Coastguard Worker def ellipsize(x): 10573*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> List[int] 10574*da0073e9SAndroid Build Coastguard Worker return x[2, ..., 4, 4, 4:8, 2].size() 10575*da0073e9SAndroid Build Coastguard Worker 10576*da0073e9SAndroid Build Coastguard Worker dummy = torch.zeros(8, 8, 8, 8, 8, 8, 8) 10577*da0073e9SAndroid Build Coastguard Worker self.checkScript(ellipsize, (dummy,), optimize=True) 10578*da0073e9SAndroid Build Coastguard Worker 10579*da0073e9SAndroid Build Coastguard Worker def test_ellipsis_start(self): 10580*da0073e9SAndroid Build Coastguard Worker def ellipsize(x): 10581*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> List[int] 10582*da0073e9SAndroid Build Coastguard Worker return x[..., 0:4, 4:8].size() 10583*da0073e9SAndroid Build Coastguard Worker dummy = torch.zeros(8, 8, 8, 8, 8) 10584*da0073e9SAndroid Build Coastguard Worker self.checkScript(ellipsize, (dummy,), optimize=True) 10585*da0073e9SAndroid Build Coastguard Worker 10586*da0073e9SAndroid Build Coastguard Worker def test_ellipsis_end(self): 10587*da0073e9SAndroid Build Coastguard Worker def ellipsize(x): 10588*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> List[int] 10589*da0073e9SAndroid Build Coastguard Worker return x[0:4, 2, ...].size() 10590*da0073e9SAndroid Build Coastguard Worker dummy = torch.zeros(8, 8, 8, 8, 8) 10591*da0073e9SAndroid Build Coastguard Worker self.checkScript(ellipsize, (dummy,), optimize=True) 10592*da0073e9SAndroid Build Coastguard Worker 10593*da0073e9SAndroid Build Coastguard Worker def test_torch_manual_seed(self): 10594*da0073e9SAndroid Build Coastguard Worker with freeze_rng_state(): 10595*da0073e9SAndroid Build Coastguard Worker def test(): 10596*da0073e9SAndroid Build Coastguard 
Worker torch.manual_seed(2) 10597*da0073e9SAndroid Build Coastguard Worker return torch.rand(1) 10598*da0073e9SAndroid Build Coastguard Worker 10599*da0073e9SAndroid Build Coastguard Worker script = torch.jit.script(test) 10600*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test(), script()) 10601*da0073e9SAndroid Build Coastguard Worker graph = script.graph_for() 10602*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::manual_seed").run(graph) 10603*da0073e9SAndroid Build Coastguard Worker 10604*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("Not a TorchDynamo suitable test") 10605*da0073e9SAndroid Build Coastguard Worker def test_index_select_shape_prop(self): 10606*da0073e9SAndroid Build Coastguard Worker 10607*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10608*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 10609*da0073e9SAndroid Build Coastguard Worker return torch.index_select(x, index=y, dim=1) 10610*da0073e9SAndroid Build Coastguard Worker 10611*da0073e9SAndroid Build Coastguard Worker a = torch.zeros(2, 2) 10612*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(4, dtype=torch.long) 10613*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False) 10614*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Float(2, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(str(foo.graph)) 10615*da0073e9SAndroid Build Coastguard Worker 10616*da0073e9SAndroid Build Coastguard Worker def test_shape_analysis_loop(self): 10617*da0073e9SAndroid Build Coastguard Worker def foo(a, b, x): 10618*da0073e9SAndroid Build Coastguard Worker c = a 10619*da0073e9SAndroid Build Coastguard Worker # on the first iteration of the loop it appears that 10620*da0073e9SAndroid Build Coastguard Worker # c should have a expand to the size of b 10621*da0073e9SAndroid Build Coastguard Worker # but on the second+ iterations, there is no broadcast and the 
10622*da0073e9SAndroid Build Coastguard Worker # sizes are different. 10623*da0073e9SAndroid Build Coastguard Worker # previously this would cause the compiler to (1) enter an infinite 10624*da0073e9SAndroid Build Coastguard Worker # loop trying to compute the shape, and (2) insert invalid 10625*da0073e9SAndroid Build Coastguard Worker # broadcasts. 10626*da0073e9SAndroid Build Coastguard Worker # this test ensure we don't regress on these issues 10627*da0073e9SAndroid Build Coastguard Worker for _ in range(2): 10628*da0073e9SAndroid Build Coastguard Worker a = c + b 10629*da0073e9SAndroid Build Coastguard Worker c = x 10630*da0073e9SAndroid Build Coastguard Worker b = x 10631*da0073e9SAndroid Build Coastguard Worker return a 10632*da0073e9SAndroid Build Coastguard Worker 10633*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False) 10634*da0073e9SAndroid Build Coastguard Worker 10635*da0073e9SAndroid Build Coastguard Worker def test_intlist_args(self): 10636*da0073e9SAndroid Build Coastguard Worker def func_1(x): 10637*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.adaptive_avg_pool1d(x, 1) 10638*da0073e9SAndroid Build Coastguard Worker 10639*da0073e9SAndroid Build Coastguard Worker def func_2(x): 10640*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1) 10641*da0073e9SAndroid Build Coastguard Worker 10642*da0073e9SAndroid Build Coastguard Worker def func_3(x): 10643*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1]) 10644*da0073e9SAndroid Build Coastguard Worker 10645*da0073e9SAndroid Build Coastguard Worker x = torch.randn(8, 8, 8) 10646*da0073e9SAndroid Build Coastguard Worker self.checkScript(func_1, [x], optimize=True) 10647*da0073e9SAndroid Build Coastguard Worker self.checkScript(func_2, [x], optimize=True) 10648*da0073e9SAndroid Build Coastguard Worker 
self.checkScript(func_3, [x], optimize=True) 10649*da0073e9SAndroid Build Coastguard Worker 10650*da0073e9SAndroid Build Coastguard Worker def test_wrong_implicit_expand(self): 10651*da0073e9SAndroid Build Coastguard Worker 10652*da0073e9SAndroid Build Coastguard Worker @_trace(torch.zeros(3), torch.zeros(1)) 10653*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 10654*da0073e9SAndroid Build Coastguard Worker return a + b 10655*da0073e9SAndroid Build Coastguard Worker 10656*da0073e9SAndroid Build Coastguard Worker a = torch.rand(4) 10657*da0073e9SAndroid Build Coastguard Worker b = torch.rand(4) 10658*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a + b, foo(a, b)) 10659*da0073e9SAndroid Build Coastguard Worker 10660*da0073e9SAndroid Build Coastguard Worker def test_builtin_args_fails(self): 10661*da0073e9SAndroid Build Coastguard Worker 10662*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Argument self not provided'): 10663*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10664*da0073e9SAndroid Build Coastguard Worker def f1(a): 10665*da0073e9SAndroid Build Coastguard Worker torch.sum(foo=4) 10666*da0073e9SAndroid Build Coastguard Worker 10667*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'specified twice'): 10668*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10669*da0073e9SAndroid Build Coastguard Worker def f2(a): 10670*da0073e9SAndroid Build Coastguard Worker torch.sum(a, self=a) 10671*da0073e9SAndroid Build Coastguard Worker 10672*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'not provided'): 10673*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10674*da0073e9SAndroid Build Coastguard Worker def f3(a): 10675*da0073e9SAndroid Build Coastguard Worker torch.sum(dim=4) 10676*da0073e9SAndroid Build Coastguard Worker 10677*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 
'for argument \'tensors\' but instead found type \'Tensor'): 10678*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10679*da0073e9SAndroid Build Coastguard Worker def f4(a): 10680*da0073e9SAndroid Build Coastguard Worker torch.cat(a) 10681*da0073e9SAndroid Build Coastguard Worker 10682*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but instead found type \'List\[int\]'): 10683*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10684*da0073e9SAndroid Build Coastguard Worker def f5(a): 10685*da0073e9SAndroid Build Coastguard Worker torch.cat([3]) 10686*da0073e9SAndroid Build Coastguard Worker 10687*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Expected a value of' 10688*da0073e9SAndroid Build Coastguard Worker r' type \'List\[int\]\' for argument' 10689*da0073e9SAndroid Build Coastguard Worker r' \'size\' but instead found type ' 10690*da0073e9SAndroid Build Coastguard Worker r'\'List\[Union\[List\[int\], int\]\]'): 10691*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10692*da0073e9SAndroid Build Coastguard Worker def f6(a): 10693*da0073e9SAndroid Build Coastguard Worker a.expand(size=[3, [4]]) 10694*da0073e9SAndroid Build Coastguard Worker 10695*da0073e9SAndroid Build Coastguard Worker def test_builtin_args(self): 10696*da0073e9SAndroid Build Coastguard Worker 10697*da0073e9SAndroid Build Coastguard Worker def t0(a): 10698*da0073e9SAndroid Build Coastguard Worker # default arg dim 10699*da0073e9SAndroid Build Coastguard Worker return torch.cat([a, a]) 10700*da0073e9SAndroid Build Coastguard Worker 10701*da0073e9SAndroid Build Coastguard Worker self.checkScript(t0, (torch.zeros(1, 1),)) 10702*da0073e9SAndroid Build Coastguard Worker 10703*da0073e9SAndroid Build Coastguard Worker def t1(a): 10704*da0073e9SAndroid Build Coastguard Worker # keywords out of order 10705*da0073e9SAndroid Build Coastguard Worker return torch.cat(dim=1, tensors=[a, 
a]) 10706*da0073e9SAndroid Build Coastguard Worker 10707*da0073e9SAndroid Build Coastguard Worker self.checkScript(t1, (torch.zeros(1, 1, 2),)) 10708*da0073e9SAndroid Build Coastguard Worker 10709*da0073e9SAndroid Build Coastguard Worker def t2(a): 10710*da0073e9SAndroid Build Coastguard Worker # mix const/non-const attributes 10711*da0073e9SAndroid Build Coastguard Worker if 1 == 1: 10712*da0073e9SAndroid Build Coastguard Worker b = 1 10713*da0073e9SAndroid Build Coastguard Worker else: 10714*da0073e9SAndroid Build Coastguard Worker b = 0 10715*da0073e9SAndroid Build Coastguard Worker return torch.sum(a, dim=b, keepdim=False) 10716*da0073e9SAndroid Build Coastguard Worker 10717*da0073e9SAndroid Build Coastguard Worker self.checkScript(t2, (torch.zeros(1, 1, 2),)) 10718*da0073e9SAndroid Build Coastguard Worker 10719*da0073e9SAndroid Build Coastguard Worker def test_parser_type_annotations(self): 10720*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 10721*da0073e9SAndroid Build Coastguard Worker def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]: 10722*da0073e9SAndroid Build Coastguard Worker return x, x 10723*da0073e9SAndroid Build Coastguard Worker ''') 10724*da0073e9SAndroid Build Coastguard Worker 10725*da0073e9SAndroid Build Coastguard Worker self.assertExpected(str(cu.foo.schema)) 10726*da0073e9SAndroid Build Coastguard Worker 10727*da0073e9SAndroid Build Coastguard Worker def test_parser_type_annotations_comment(self): 10728*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 10729*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 10730*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor] 10731*da0073e9SAndroid Build Coastguard Worker return x, x 10732*da0073e9SAndroid Build Coastguard Worker ''') 10733*da0073e9SAndroid Build Coastguard Worker 10734*da0073e9SAndroid Build Coastguard Worker 
self.assertExpected(str(cu.foo.schema)) 10735*da0073e9SAndroid Build Coastguard Worker 10736*da0073e9SAndroid Build Coastguard Worker def test_parser_type_annotations_unknown_type(self): 10737*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Unknown type name 'Foo'"): 10738*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 10739*da0073e9SAndroid Build Coastguard Worker def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]: 10740*da0073e9SAndroid Build Coastguard Worker return x, x 10741*da0073e9SAndroid Build Coastguard Worker ''') 10742*da0073e9SAndroid Build Coastguard Worker 10743*da0073e9SAndroid Build Coastguard Worker def test_parser_type_annotations_subscript_non_ident(self): 10744*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'): 10745*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 10746*da0073e9SAndroid Build Coastguard Worker def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]: 10747*da0073e9SAndroid Build Coastguard Worker return x, x 10748*da0073e9SAndroid Build Coastguard Worker ''') 10749*da0073e9SAndroid Build Coastguard Worker 10750*da0073e9SAndroid Build Coastguard Worker def test_parser_type_annotations_subscript_tensor(self): 10751*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'): 10752*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 10753*da0073e9SAndroid Build Coastguard Worker def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: 10754*da0073e9SAndroid Build Coastguard Worker return x, x 10755*da0073e9SAndroid Build Coastguard Worker ''') 10756*da0073e9SAndroid Build Coastguard Worker 10757*da0073e9SAndroid Build Coastguard Worker def test_parser_type_annotations_incompatible_expression(self): 
10758*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'): 10759*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 10760*da0073e9SAndroid Build Coastguard Worker def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]: 10761*da0073e9SAndroid Build Coastguard Worker return x, x 10762*da0073e9SAndroid Build Coastguard Worker ''') 10763*da0073e9SAndroid Build Coastguard Worker 10764*da0073e9SAndroid Build Coastguard Worker def test_gather_dynamic_index(self): 10765*da0073e9SAndroid Build Coastguard Worker def t(x): 10766*da0073e9SAndroid Build Coastguard Worker gather1 = x[0] 10767*da0073e9SAndroid Build Coastguard Worker idx = 0 + 1 10768*da0073e9SAndroid Build Coastguard Worker gather2 = x[idx] 10769*da0073e9SAndroid Build Coastguard Worker return gather1 + gather2 10770*da0073e9SAndroid Build Coastguard Worker 10771*da0073e9SAndroid Build Coastguard Worker self.checkScript(t, (torch.zeros(3, 2, 3),)) 10772*da0073e9SAndroid Build Coastguard Worker 10773*da0073e9SAndroid Build Coastguard Worker def test_torch_ignore_conversion_to_none(self): 10774*da0073e9SAndroid Build Coastguard Worker class A(torch.nn.Module): 10775*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 10776*da0073e9SAndroid Build Coastguard Worker def ignored(self, a: int) -> None: 10777*da0073e9SAndroid Build Coastguard Worker l: int = len([2 for i in range(a) if i > 2]) 10778*da0073e9SAndroid Build Coastguard Worker return 10779*da0073e9SAndroid Build Coastguard Worker 10780*da0073e9SAndroid Build Coastguard Worker def forward(self) -> int: 10781*da0073e9SAndroid Build Coastguard Worker a: int = 4 10782*da0073e9SAndroid Build Coastguard Worker b: int = 5 10783*da0073e9SAndroid Build Coastguard Worker self.ignored(a) 10784*da0073e9SAndroid Build Coastguard Worker return a + b 10785*da0073e9SAndroid Build Coastguard Worker 10786*da0073e9SAndroid Build 
Coastguard Worker class B(torch.nn.Module): 10787*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 10788*da0073e9SAndroid Build Coastguard Worker def ignored(self, a: int): 10789*da0073e9SAndroid Build Coastguard Worker l: int = len([2 for i in range(a) if i > 2]) 10790*da0073e9SAndroid Build Coastguard Worker return 10791*da0073e9SAndroid Build Coastguard Worker 10792*da0073e9SAndroid Build Coastguard Worker def forward(self) -> int: 10793*da0073e9SAndroid Build Coastguard Worker a: int = 4 10794*da0073e9SAndroid Build Coastguard Worker b: int = 5 10795*da0073e9SAndroid Build Coastguard Worker self.ignored(a) 10796*da0073e9SAndroid Build Coastguard Worker return a + b 10797*da0073e9SAndroid Build Coastguard Worker 10798*da0073e9SAndroid Build Coastguard Worker modelA = torch.jit.script(A()) 10799*da0073e9SAndroid Build Coastguard Worker self.assertEqual(modelA(), 9) 10800*da0073e9SAndroid Build Coastguard Worker 10801*da0073e9SAndroid Build Coastguard Worker modelB = torch.jit.script(B()) 10802*da0073e9SAndroid Build Coastguard Worker self.assertEqual(modelB(), 9) 10803*da0073e9SAndroid Build Coastguard Worker 10804*da0073e9SAndroid Build Coastguard Worker def test_addmm_grad(self): 10805*da0073e9SAndroid Build Coastguard Worker """ This test checks several things: 10806*da0073e9SAndroid Build Coastguard Worker 1. An expand node was inserted before the addmm operating on the 10807*da0073e9SAndroid Build Coastguard Worker bias term. 10808*da0073e9SAndroid Build Coastguard Worker 2. The fused form of addmm appears in the ultimate graph that's 10809*da0073e9SAndroid Build Coastguard Worker executed. 10810*da0073e9SAndroid Build Coastguard Worker 3. A sum op was emitted for accumulating gradients along the 0th 10811*da0073e9SAndroid Build Coastguard Worker (expanded) dimension of the bias term. 10812*da0073e9SAndroid Build Coastguard Worker 4. 
The correct symbolic representation for the backward pass of the 10813*da0073e9SAndroid Build Coastguard Worker mm operator was emitted (x.t() -> mm) 10814*da0073e9SAndroid Build Coastguard Worker 10815*da0073e9SAndroid Build Coastguard Worker TODO: we should actually check these conditions once we have a way 10816*da0073e9SAndroid Build Coastguard Worker to dump the GraphExecutor state. Namely the processed forward graph 10817*da0073e9SAndroid Build Coastguard Worker and the backward graph. 10818*da0073e9SAndroid Build Coastguard Worker """ 10819*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10820*da0073e9SAndroid Build Coastguard Worker def addmm_grad_test(b, x, w): 10821*da0073e9SAndroid Build Coastguard Worker return torch.addmm(b, x, w) 10822*da0073e9SAndroid Build Coastguard Worker 10823*da0073e9SAndroid Build Coastguard Worker # Initialize param and input values 10824*da0073e9SAndroid Build Coastguard Worker w_init = torch.rand(2, 5) 10825*da0073e9SAndroid Build Coastguard Worker b_init = torch.rand(5) 10826*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 2) 10827*da0073e9SAndroid Build Coastguard Worker 10828*da0073e9SAndroid Build Coastguard Worker # Clone trainable params 10829*da0073e9SAndroid Build Coastguard Worker b = b_init.clone() 10830*da0073e9SAndroid Build Coastguard Worker b.requires_grad_() 10831*da0073e9SAndroid Build Coastguard Worker w = w_init.clone() 10832*da0073e9SAndroid Build Coastguard Worker w.requires_grad_() 10833*da0073e9SAndroid Build Coastguard Worker 10834*da0073e9SAndroid Build Coastguard Worker # Test symbolic differentiation 10835*da0073e9SAndroid Build Coastguard Worker y = addmm_grad_test(b, x, w) 10836*da0073e9SAndroid Build Coastguard Worker y.sum().backward() 10837*da0073e9SAndroid Build Coastguard Worker 10838*da0073e9SAndroid Build Coastguard Worker # clone params for autograd reference 10839*da0073e9SAndroid Build Coastguard Worker b_ref = b_init.clone() 10840*da0073e9SAndroid Build Coastguard 
Worker b_ref.requires_grad_() 10841*da0073e9SAndroid Build Coastguard Worker w_ref = w_init.clone() 10842*da0073e9SAndroid Build Coastguard Worker w_ref.requires_grad_() 10843*da0073e9SAndroid Build Coastguard Worker y_ref = torch.addmm(b_ref, x, w_ref) 10844*da0073e9SAndroid Build Coastguard Worker y_ref.sum().backward() 10845*da0073e9SAndroid Build Coastguard Worker 10846*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w.grad, w_ref.grad) 10847*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.grad, b_ref.grad) 10848*da0073e9SAndroid Build Coastguard Worker 10849*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "running tests on cuda to verify cudnn fix") 10850*da0073e9SAndroid Build Coastguard Worker def test_batch_norm_inference_backward_cuda(self): 10851*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 10852*da0073e9SAndroid Build Coastguard Worker class MyBatchNorm(torch.nn.Module): 10853*da0073e9SAndroid Build Coastguard Worker def __init__(self, num_features, affine, track_running_stats): 10854*da0073e9SAndroid Build Coastguard Worker super().__init__() 10855*da0073e9SAndroid Build Coastguard Worker self.bn = torch.nn.BatchNorm2d( 10856*da0073e9SAndroid Build Coastguard Worker num_features, 1e-5, affine=affine, track_running_stats=track_running_stats).float() 10857*da0073e9SAndroid Build Coastguard Worker 10858*da0073e9SAndroid Build Coastguard Worker def forward(self, x: torch.Tensor): 10859*da0073e9SAndroid Build Coastguard Worker o = self.bn(x) 10860*da0073e9SAndroid Build Coastguard Worker o = torch.nn.functional.relu(o) 10861*da0073e9SAndroid Build Coastguard Worker return o 10862*da0073e9SAndroid Build Coastguard Worker 10863*da0073e9SAndroid Build Coastguard Worker batch = 4 10864*da0073e9SAndroid Build Coastguard Worker c = 2 10865*da0073e9SAndroid Build Coastguard Worker hw = 3 10866*da0073e9SAndroid Build Coastguard Worker # Initialize param and input values 
10867*da0073e9SAndroid Build Coastguard Worker x_init = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda() 10868*da0073e9SAndroid Build Coastguard Worker grad = torch.randn(batch, c, hw, hw, dtype=torch.float).cuda() 10869*da0073e9SAndroid Build Coastguard Worker 10870*da0073e9SAndroid Build Coastguard Worker training = False 10871*da0073e9SAndroid Build Coastguard Worker affine = True 10872*da0073e9SAndroid Build Coastguard Worker track_running_stats = True 10873*da0073e9SAndroid Build Coastguard Worker 10874*da0073e9SAndroid Build Coastguard Worker module = torch.jit.script(MyBatchNorm(c, affine, track_running_stats)).cuda() 10875*da0073e9SAndroid Build Coastguard Worker ref_module = MyBatchNorm(c, affine, track_running_stats).cuda() 10876*da0073e9SAndroid Build Coastguard Worker module.eval() 10877*da0073e9SAndroid Build Coastguard Worker ref_module.eval() 10878*da0073e9SAndroid Build Coastguard Worker 10879*da0073e9SAndroid Build Coastguard Worker jit_module = torch.jit.script(module) 10880*da0073e9SAndroid Build Coastguard Worker ref_module.load_state_dict(module.state_dict()) 10881*da0073e9SAndroid Build Coastguard Worker 10882*da0073e9SAndroid Build Coastguard Worker x = x_init.detach().clone() 10883*da0073e9SAndroid Build Coastguard Worker x.requires_grad_() 10884*da0073e9SAndroid Build Coastguard Worker x_ref = x_init.detach().clone() 10885*da0073e9SAndroid Build Coastguard Worker x_ref.requires_grad_() 10886*da0073e9SAndroid Build Coastguard Worker 10887*da0073e9SAndroid Build Coastguard Worker # Test symbolic differentiation 10888*da0073e9SAndroid Build Coastguard Worker # Run Forward and Backward thrice to trigger autodiff graph 10889*da0073e9SAndroid Build Coastguard Worker for i in range(0, 3): 10890*da0073e9SAndroid Build Coastguard Worker y = jit_module(x) 10891*da0073e9SAndroid Build Coastguard Worker y.backward(grad) 10892*da0073e9SAndroid Build Coastguard Worker x.grad.zero_() 10893*da0073e9SAndroid Build Coastguard Worker 
10894*da0073e9SAndroid Build Coastguard Worker module.bn.running_mean.zero_() 10895*da0073e9SAndroid Build Coastguard Worker module.bn.running_var.fill_(1.0) 10896*da0073e9SAndroid Build Coastguard Worker ref_module.bn.running_mean.zero_() 10897*da0073e9SAndroid Build Coastguard Worker ref_module.bn.running_var.fill_(1.0) 10898*da0073e9SAndroid Build Coastguard Worker 10899*da0073e9SAndroid Build Coastguard Worker # run jitted module 10900*da0073e9SAndroid Build Coastguard Worker y = jit_module(x) 10901*da0073e9SAndroid Build Coastguard Worker y.backward(grad) 10902*da0073e9SAndroid Build Coastguard Worker # reference computation 10903*da0073e9SAndroid Build Coastguard Worker y_ref = ref_module(x_ref) 10904*da0073e9SAndroid Build Coastguard Worker y_ref.backward(grad) 10905*da0073e9SAndroid Build Coastguard Worker 10906*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y_ref, y) 10907*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, x_ref.grad) 10908*da0073e9SAndroid Build Coastguard Worker self.assertEqual(module.bn.running_mean, ref_module.bn.running_mean) 10909*da0073e9SAndroid Build Coastguard Worker self.assertEqual(module.bn.running_var, ref_module.bn.running_var) 10910*da0073e9SAndroid Build Coastguard Worker 10911*da0073e9SAndroid Build Coastguard Worker def test_zeros(self): 10912*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 10913*da0073e9SAndroid Build Coastguard Worker __constants__ = ['d'] 10914*da0073e9SAndroid Build Coastguard Worker 10915*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 10916*da0073e9SAndroid Build Coastguard Worker super().__init__() 10917*da0073e9SAndroid Build Coastguard Worker self.d = torch.device('cpu') 10918*da0073e9SAndroid Build Coastguard Worker 10919*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 10920*da0073e9SAndroid Build Coastguard Worker def create(self): 10921*da0073e9SAndroid Build Coastguard Worker return torch.zeros([1, 1, 
2], dtype=torch.float, device=self.d, layout=torch.strided) 10922*da0073e9SAndroid Build Coastguard Worker 10923*da0073e9SAndroid Build Coastguard Worker r = M().create() 10924*da0073e9SAndroid Build Coastguard Worker self.assertEqual(r.dtype, torch.float) 10925*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r) 10926*da0073e9SAndroid Build Coastguard Worker 10927*da0073e9SAndroid Build Coastguard Worker def fn(): 10928*da0073e9SAndroid Build Coastguard Worker return torch.zeros((1, 2, 3)) 10929*da0073e9SAndroid Build Coastguard Worker 10930*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 10931*da0073e9SAndroid Build Coastguard Worker 10932*da0073e9SAndroid Build Coastguard Worker def test_vararg_zeros(self): 10933*da0073e9SAndroid Build Coastguard Worker def foo(): 10934*da0073e9SAndroid Build Coastguard Worker return torch.zeros(3, 4, 5, dtype=torch.int) 10935*da0073e9SAndroid Build Coastguard Worker 10936*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, ()) 10937*da0073e9SAndroid Build Coastguard Worker 10938*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "the original version of test_rand") 10939*da0073e9SAndroid Build Coastguard Worker def test_rand(self): 10940*da0073e9SAndroid Build Coastguard Worker def test_rand(): 10941*da0073e9SAndroid Build Coastguard Worker a = torch.rand([3, 4]) 10942*da0073e9SAndroid Build Coastguard Worker return a + 1.0 - a 10943*da0073e9SAndroid Build Coastguard Worker 10944*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_rand, ()) 10945*da0073e9SAndroid Build Coastguard Worker fn = torch.jit.script(test_rand) 10946*da0073e9SAndroid Build Coastguard Worker out = fn() 10947*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.dtype, torch.get_default_dtype()) 10948*da0073e9SAndroid Build Coastguard Worker g = fn.graph_for() 10949*da0073e9SAndroid Build Coastguard Worker # 
Testing shape analysis correctly setting type 10950*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: 10951*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Double(*, *, requires_grad=0, device=cpu)") \ 10952*da0073e9SAndroid Build Coastguard Worker .check_not("Float(*, *, requires_grad=0, device=cpu)").run(g) 10953*da0073e9SAndroid Build Coastguard Worker 10954*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10955*da0073e9SAndroid Build Coastguard Worker def randint(): 10956*da0073e9SAndroid Build Coastguard Worker return torch.randint(0, 5, [1, 2]) 10957*da0073e9SAndroid Build Coastguard Worker out = randint() 10958*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.dtype, torch.int64) 10959*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR != ProfilingMode.SIMPLE: 10960*da0073e9SAndroid Build Coastguard Worker FileCheck().check("Long(*, *, requires_grad=0, device=cpu)") \ 10961*da0073e9SAndroid Build Coastguard Worker .check_not("Float(*, *, requires_grad=0, device=cpu)") \ 10962*da0073e9SAndroid Build Coastguard Worker .check_not("Double(*, *, requires_grad=0, device=cpu)") \ 10963*da0073e9SAndroid Build Coastguard Worker .run(randint.graph_for()) 10964*da0073e9SAndroid Build Coastguard Worker 10965*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "no CUDA") 10966*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") 10967*da0073e9SAndroid Build Coastguard Worker def test_autodiff_complex(self): 10968*da0073e9SAndroid Build Coastguard Worker def foo(x: torch.Tensor, y: torch.Tensor, W: torch.Tensor): 10969*da0073e9SAndroid Build Coastguard Worker return torch.exp(torch.mm(torch.complex(x, y), W.cfloat())) 10970*da0073e9SAndroid Build Coastguard Worker 10971*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 10972*da0073e9SAndroid Build Coastguard Worker def jitted_foo(x: 
torch.Tensor, y: torch.Tensor, W: torch.Tensor): 10973*da0073e9SAndroid Build Coastguard Worker return torch.exp(torch.mm(torch.complex(x, y), W.cfloat())) 10974*da0073e9SAndroid Build Coastguard Worker 10975*da0073e9SAndroid Build Coastguard Worker x = torch.randn(128, 16, dtype=torch.float32, device='cuda:0') 10976*da0073e9SAndroid Build Coastguard Worker y = torch.randn(128, 16, dtype=torch.float32, device='cuda:0') 10977*da0073e9SAndroid Build Coastguard Worker W = torch.randn(16, 1, dtype=torch.float32, device='cuda:0', requires_grad=True) 10978*da0073e9SAndroid Build Coastguard Worker W.data /= 4 10979*da0073e9SAndroid Build Coastguard Worker 10980*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 10981*da0073e9SAndroid Build Coastguard Worker for i in range(4): 10982*da0073e9SAndroid Build Coastguard Worker self.assertTrue((foo(x, y, W).grad_fn is None) == (jitted_foo(x, y, W).grad_fn is None)) 10983*da0073e9SAndroid Build Coastguard Worker 10984*da0073e9SAndroid Build Coastguard Worker 10985*da0073e9SAndroid Build Coastguard Worker def test_linear_grad(self): 10986*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 10987*da0073e9SAndroid Build Coastguard Worker def t(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor]): 10988*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.linear(x, w, b) 10989*da0073e9SAndroid Build Coastguard Worker 10990*da0073e9SAndroid Build Coastguard Worker x_init = torch.randn(4, 2) 10991*da0073e9SAndroid Build Coastguard Worker w_init = torch.randn(3, 2) 10992*da0073e9SAndroid Build Coastguard Worker b_init = torch.randn(3) 10993*da0073e9SAndroid Build Coastguard Worker grad = torch.randn(4, 3) 10994*da0073e9SAndroid Build Coastguard Worker 10995*da0073e9SAndroid Build Coastguard Worker with disable_autodiff_subgraph_inlining(): 10996*da0073e9SAndroid Build Coastguard Worker # script module 10997*da0073e9SAndroid Build 
Coastguard Worker jit_t = torch.jit.script(t) 10998*da0073e9SAndroid Build Coastguard Worker 10999*da0073e9SAndroid Build Coastguard Worker x = x_init.detach().requires_grad_() 11000*da0073e9SAndroid Build Coastguard Worker w = w_init.detach().requires_grad_() 11001*da0073e9SAndroid Build Coastguard Worker b = b_init.detach().requires_grad_() 11002*da0073e9SAndroid Build Coastguard Worker x_ref = x_init.detach().requires_grad_() 11003*da0073e9SAndroid Build Coastguard Worker w_ref = w_init.detach().requires_grad_() 11004*da0073e9SAndroid Build Coastguard Worker b_ref = b_init.detach().requires_grad_() 11005*da0073e9SAndroid Build Coastguard Worker 11006*da0073e9SAndroid Build Coastguard Worker # profiling/optimization runs 11007*da0073e9SAndroid Build Coastguard Worker jit_o = jit_t(x, w, b) 11008*da0073e9SAndroid Build Coastguard Worker jit_o.backward(grad) 11009*da0073e9SAndroid Build Coastguard Worker jit_o = jit_t(x, w, b) 11010*da0073e9SAndroid Build Coastguard Worker jit_o.backward(grad) 11011*da0073e9SAndroid Build Coastguard Worker 11012*da0073e9SAndroid Build Coastguard Worker x.grad.zero_() 11013*da0073e9SAndroid Build Coastguard Worker w.grad.zero_() 11014*da0073e9SAndroid Build Coastguard Worker b.grad.zero_() 11015*da0073e9SAndroid Build Coastguard Worker jit_o = jit_t(x, w, b) 11016*da0073e9SAndroid Build Coastguard Worker jit_o.backward(grad) 11017*da0073e9SAndroid Build Coastguard Worker o = t(x_ref, w_ref, b_ref) 11018*da0073e9SAndroid Build Coastguard Worker o.backward(grad) 11019*da0073e9SAndroid Build Coastguard Worker 11020*da0073e9SAndroid Build Coastguard Worker self.assertEqual(jit_o, o) 11021*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, x_ref.grad) 11022*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w.grad, w_ref.grad) 11023*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.grad, b_ref.grad) 11024*da0073e9SAndroid Build Coastguard Worker 11025*da0073e9SAndroid Build Coastguard Worker x.grad.zero_() 
11026*da0073e9SAndroid Build Coastguard Worker w.grad.zero_() 11027*da0073e9SAndroid Build Coastguard Worker x_ref.grad.zero_() 11028*da0073e9SAndroid Build Coastguard Worker w_ref.grad.zero_() 11029*da0073e9SAndroid Build Coastguard Worker jit_o = jit_t(x, w, None) 11030*da0073e9SAndroid Build Coastguard Worker jit_o.backward(grad) 11031*da0073e9SAndroid Build Coastguard Worker o = t(x_ref, w_ref, None) 11032*da0073e9SAndroid Build Coastguard Worker o.backward(grad) 11033*da0073e9SAndroid Build Coastguard Worker 11034*da0073e9SAndroid Build Coastguard Worker self.assertEqual(jit_o, o) 11035*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, x_ref.grad) 11036*da0073e9SAndroid Build Coastguard Worker self.assertEqual(w.grad, w_ref.grad) 11037*da0073e9SAndroid Build Coastguard Worker 11038*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo doesn't support profile") 11039*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "the profiling version of test_rand") 11040*da0073e9SAndroid Build Coastguard Worker def test_rand_profiling(self): 11041*da0073e9SAndroid Build Coastguard Worker def test_rand(): 11042*da0073e9SAndroid Build Coastguard Worker a = torch.rand([3, 4]) 11043*da0073e9SAndroid Build Coastguard Worker return a + 1.0 - a 11044*da0073e9SAndroid Build Coastguard Worker 11045*da0073e9SAndroid Build Coastguard Worker # Testing shape analysis correctly setting type 11046*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 11047*da0073e9SAndroid Build Coastguard Worker with num_profiled_runs(1): 11048*da0073e9SAndroid Build Coastguard Worker fn = torch.jit.script(test_rand) 11049*da0073e9SAndroid Build Coastguard Worker out = fn() 11050*da0073e9SAndroid Build Coastguard Worker graph_str = torch.jit.last_executed_optimized_graph() 11051*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.dtype, torch.float) 11052*da0073e9SAndroid 
Build Coastguard Worker FileCheck().check("Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)") \ 11053*da0073e9SAndroid Build Coastguard Worker .check_not("Double(3, 4, strides=[4, 1], requires_grad=0, device=cpu)").run(graph_str) 11054*da0073e9SAndroid Build Coastguard Worker 11055*da0073e9SAndroid Build Coastguard Worker # fn = self.checkScript(test_rand, ()) 11056*da0073e9SAndroid Build Coastguard Worker # out = fn() 11057*da0073e9SAndroid Build Coastguard Worker # self.assertEqual(out.dtype, torch.float) 11058*da0073e9SAndroid Build Coastguard Worker 11059*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11060*da0073e9SAndroid Build Coastguard Worker def randint(): 11061*da0073e9SAndroid Build Coastguard Worker return torch.randint(0, 5, [1, 2]) 11062*da0073e9SAndroid Build Coastguard Worker 11063*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 11064*da0073e9SAndroid Build Coastguard Worker with num_profiled_runs(1): 11065*da0073e9SAndroid Build Coastguard Worker out = randint() 11066*da0073e9SAndroid Build Coastguard Worker graph_str = torch.jit.last_executed_optimized_graph() 11067*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.dtype, torch.int64) 11068*da0073e9SAndroid Build Coastguard Worker FileCheck().check("profiled_type=Long(1, 2, strides=[2, 1], requires_grad=0, device=cpu)").run(graph_str) 11069*da0073e9SAndroid Build Coastguard Worker 11070*da0073e9SAndroid Build Coastguard Worker 11071*da0073e9SAndroid Build Coastguard Worker def test_erase_number_types(self): 11072*da0073e9SAndroid Build Coastguard Worker def func(a): 11073*da0073e9SAndroid Build Coastguard Worker b = 7 + 1 + 3 11074*da0073e9SAndroid Build Coastguard Worker c = a + b 11075*da0073e9SAndroid Build Coastguard Worker c += b 11076*da0073e9SAndroid Build Coastguard Worker return c 11077*da0073e9SAndroid Build Coastguard Worker 11078*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(func).graph 
11079*da0073e9SAndroid Build Coastguard Worker FileCheck().check("int = prim::Constant").check("aten::add_").run(str(graph)) 11080*da0073e9SAndroid Build Coastguard Worker self.run_pass("erase_number_types", graph) 11081*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("int = prim::Constant").run(str(graph)) 11082*da0073e9SAndroid Build Coastguard Worker 11083*da0073e9SAndroid Build Coastguard Worker def test_refine_tuple_types(self): 11084*da0073e9SAndroid Build Coastguard Worker # TupleConstruct output type is not correct here. 11085*da0073e9SAndroid Build Coastguard Worker graph_str = """ 11086*da0073e9SAndroid Build Coastguard Worker graph(%a : Float(123), %b : Float(4, 5, 6)): 11087*da0073e9SAndroid Build Coastguard Worker %c : (Tensor, Tensor) = prim::TupleConstruct(%a, %b) 11088*da0073e9SAndroid Build Coastguard Worker return (%c) 11089*da0073e9SAndroid Build Coastguard Worker """ 11090*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(graph_str) 11091*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_refine_tuple_types(graph) 11092*da0073e9SAndroid Build Coastguard Worker 11093*da0073e9SAndroid Build Coastguard Worker # After the pass, the output type should've been updated. 11094*da0073e9SAndroid Build Coastguard Worker self.assertTrue('(Float(123), Float(4, 5, 6))' in str(graph.findNode('prim::TupleConstruct').output())) 11095*da0073e9SAndroid Build Coastguard Worker 11096*da0073e9SAndroid Build Coastguard Worker # TODO(henrytu): Add test for RefineTypes for NamedTuple when it's supported by IR parser. 
11097*da0073e9SAndroid Build Coastguard Worker 11098*da0073e9SAndroid Build Coastguard Worker def test_remove_dropout(self): 11099*da0073e9SAndroid Build Coastguard Worker weight_0_shape = (20, 5) 11100*da0073e9SAndroid Build Coastguard Worker weight_1_shape = (20, 20) 11101*da0073e9SAndroid Build Coastguard Worker input_shape = (10, 5) 11102*da0073e9SAndroid Build Coastguard Worker 11103*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 11104*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 11105*da0073e9SAndroid Build Coastguard Worker super().__init__() 11106*da0073e9SAndroid Build Coastguard Worker self.weight_0 = torch.nn.Parameter(torch.rand(weight_0_shape)) 11107*da0073e9SAndroid Build Coastguard Worker self.weight_1 = torch.nn.Parameter(torch.rand(weight_1_shape)) 11108*da0073e9SAndroid Build Coastguard Worker 11109*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 11110*da0073e9SAndroid Build Coastguard Worker o = F.linear(x, self.weight_0) 11111*da0073e9SAndroid Build Coastguard Worker o = F.dropout(o, training=self.training) 11112*da0073e9SAndroid Build Coastguard Worker o = F.linear(o, self.weight_1) 11113*da0073e9SAndroid Build Coastguard Worker return o 11114*da0073e9SAndroid Build Coastguard Worker 11115*da0073e9SAndroid Build Coastguard Worker data = torch.rand(input_shape) 11116*da0073e9SAndroid Build Coastguard Worker m = M() 11117*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(m) 11118*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'Dropout removal module in training mode is not yet supported'): 11119*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_remove_dropout(m._c) 11120*da0073e9SAndroid Build Coastguard Worker m.eval() 11121*da0073e9SAndroid Build Coastguard Worker ref_res = m(data) 11122*da0073e9SAndroid Build Coastguard Worker # Need to inline otherwise we see instances of Function. 
11123*da0073e9SAndroid Build Coastguard Worker # We would have to use torch.linear/dropout to get around it otherwise. 11124*da0073e9SAndroid Build Coastguard Worker from torch.jit._recursive import wrap_cpp_module 11125*da0073e9SAndroid Build Coastguard Worker m = wrap_cpp_module(torch._C._freeze_module(m._c)) 11126*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_remove_dropout(m._c) 11127*da0073e9SAndroid Build Coastguard Worker res = m(data) 11128*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("aten::dropout").run(str(m.graph)) 11129*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(ref_res, res, rtol=1e-2, atol=1e-3) 11130*da0073e9SAndroid Build Coastguard Worker 11131*da0073e9SAndroid Build Coastguard Worker def test_unfold_zero_dim(self): 11132*da0073e9SAndroid Build Coastguard Worker def fn(x): 11133*da0073e9SAndroid Build Coastguard Worker return x.unfold(0, 1, 1) 11134*da0073e9SAndroid Build Coastguard Worker 11135*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(fn).graph 11136*da0073e9SAndroid Build Coastguard Worker torch._C._jit_pass_complete_shape_analysis(graph, (torch.tensor(0.39),), False) 11137*da0073e9SAndroid Build Coastguard Worker out_dims = fn(torch.tensor(0.3923)).ndim 11138*da0073e9SAndroid Build Coastguard Worker self.assertEqual(graph.findNode("aten::unfold").output().type().dim(), out_dims) 11139*da0073e9SAndroid Build Coastguard Worker 11140*da0073e9SAndroid Build Coastguard Worker def test_mm_batching(self): 11141*da0073e9SAndroid Build Coastguard Worker 11142*da0073e9SAndroid Build Coastguard Worker with enable_profiling_mode_for_profiling_tests(): 11143*da0073e9SAndroid Build Coastguard Worker lstm_cell = torch.jit.script(LSTMCellS) 11144*da0073e9SAndroid Build Coastguard Worker 11145*da0073e9SAndroid Build Coastguard Worker def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh): 11146*da0073e9SAndroid Build Coastguard Worker for i in range(x.size(0)): 11147*da0073e9SAndroid Build 
Coastguard Worker hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh) 11148*da0073e9SAndroid Build Coastguard Worker return hx 11149*da0073e9SAndroid Build Coastguard Worker 11150*da0073e9SAndroid Build Coastguard Worker slstm = torch.jit.script(lstm) 11151*da0073e9SAndroid Build Coastguard Worker 11152*da0073e9SAndroid Build Coastguard Worker inputs = get_lstm_inputs('cpu', training=True, seq_length=10) 11153*da0073e9SAndroid Build Coastguard Worker slstm(*inputs, profile_and_replay=True).sum().backward(retain_graph=True) 11154*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.PROFILING: 11155*da0073e9SAndroid Build Coastguard Worker slstm(*inputs, profile_and_replay=True).sum().backward() 11156*da0073e9SAndroid Build Coastguard Worker 11157*da0073e9SAndroid Build Coastguard Worker fw_graph = slstm.graph_for(*inputs) 11158*da0073e9SAndroid Build Coastguard Worker if GRAPH_EXECUTOR == ProfilingMode.LEGACY: 11159*da0073e9SAndroid Build Coastguard Worker bw_graph = backward_graph(slstm, diff_graph_idx=0) 11160*da0073e9SAndroid Build Coastguard Worker self.assertTrue('prim::MMBatchSide' in str(fw_graph)) 11161*da0073e9SAndroid Build Coastguard Worker self.assertTrue('prim::MMTreeReduce' in str(bw_graph)) 11162*da0073e9SAndroid Build Coastguard Worker 11163*da0073e9SAndroid Build Coastguard Worker sout = slstm(*inputs) 11164*da0073e9SAndroid Build Coastguard Worker out = lstm(*inputs) 11165*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sout, out) 11166*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.autograd.grad(sout.sum(), inputs), 11167*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad(out.sum(), inputs)) 11168*da0073e9SAndroid Build Coastguard Worker 11169*da0073e9SAndroid Build Coastguard Worker def test_loop_unrolling(self): 11170*da0073e9SAndroid Build Coastguard Worker def fn(x): 11171*da0073e9SAndroid Build Coastguard Worker y = 0 11172*da0073e9SAndroid Build Coastguard Worker for i in 
range(int(x)): 11173*da0073e9SAndroid Build Coastguard Worker y -= i 11174*da0073e9SAndroid Build Coastguard Worker return y 11175*da0073e9SAndroid Build Coastguard Worker 11176*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(fn).graph 11177*da0073e9SAndroid Build Coastguard Worker self.run_pass('loop_unrolling', graph) 11178*da0073e9SAndroid Build Coastguard Worker unroll_factor = 8 11179*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::Loop").check_count("aten::sub", unroll_factor) \ 11180*da0073e9SAndroid Build Coastguard Worker .check("prim::Loop").check("aten::sub").run(str(graph)) 11181*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.tensor(10),)) 11182*da0073e9SAndroid Build Coastguard Worker 11183*da0073e9SAndroid Build Coastguard Worker def test_loop_unrolling_const(self): 11184*da0073e9SAndroid Build Coastguard Worker def fn(): 11185*da0073e9SAndroid Build Coastguard Worker y = 0 11186*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 11187*da0073e9SAndroid Build Coastguard Worker y -= 1 11188*da0073e9SAndroid Build Coastguard Worker return y 11189*da0073e9SAndroid Build Coastguard Worker 11190*da0073e9SAndroid Build Coastguard Worker def fn2(): 11191*da0073e9SAndroid Build Coastguard Worker y = 0 11192*da0073e9SAndroid Build Coastguard Worker for i in range(10): 11193*da0073e9SAndroid Build Coastguard Worker y -= i 11194*da0073e9SAndroid Build Coastguard Worker return y 11195*da0073e9SAndroid Build Coastguard Worker 11196*da0073e9SAndroid Build Coastguard Worker def check(fn, name): 11197*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(fn).graph 11198*da0073e9SAndroid Build Coastguard Worker self.run_pass('loop_unrolling', graph) 11199*da0073e9SAndroid Build Coastguard Worker # entirely unrolled 11200*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::Loop'").run(str(graph)) 11201*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 
11202*da0073e9SAndroid Build Coastguard Worker 11203*da0073e9SAndroid Build Coastguard Worker check(fn, 'add_const') 11204*da0073e9SAndroid Build Coastguard Worker check(fn2, 'add_iter') 11205*da0073e9SAndroid Build Coastguard Worker 11206*da0073e9SAndroid Build Coastguard Worker def test_loop_unrolling_nested(self): 11207*da0073e9SAndroid Build Coastguard Worker def fn(x): 11208*da0073e9SAndroid Build Coastguard Worker y = 0 11209*da0073e9SAndroid Build Coastguard Worker for _ in range(10): 11210*da0073e9SAndroid Build Coastguard Worker for j in range(int(x)): 11211*da0073e9SAndroid Build Coastguard Worker y -= j 11212*da0073e9SAndroid Build Coastguard Worker return y 11213*da0073e9SAndroid Build Coastguard Worker 11214*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(fn).graph 11215*da0073e9SAndroid Build Coastguard Worker self.run_pass('loop_unrolling', graph) 11216*da0073e9SAndroid Build Coastguard Worker # inner loop with 8 subs followed by loop epilogue 11217*da0073e9SAndroid Build Coastguard Worker unroll_factor = 8 11218*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::Loop").check("prim::Loop").check_count('aten::sub', unroll_factor) \ 11219*da0073e9SAndroid Build Coastguard Worker .check("prim::Loop").check("aten::sub").run(str(graph)) 11220*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.tensor(10),)) 11221*da0073e9SAndroid Build Coastguard Worker 11222*da0073e9SAndroid Build Coastguard Worker def test_loop_unroll_unused_counter(self): 11223*da0073e9SAndroid Build Coastguard Worker def fn(x): 11224*da0073e9SAndroid Build Coastguard Worker y = 0 11225*da0073e9SAndroid Build Coastguard Worker for _ in range(int(x)): 11226*da0073e9SAndroid Build Coastguard Worker y -= 1 11227*da0073e9SAndroid Build Coastguard Worker return y 11228*da0073e9SAndroid Build Coastguard Worker 11229*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(fn).graph 11230*da0073e9SAndroid Build Coastguard Worker 
self.run_pass('loop_unrolling', graph) 11231*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::Loop").check_not("aten::add").check("return") \ 11232*da0073e9SAndroid Build Coastguard Worker .run(str(graph)) 11233*da0073e9SAndroid Build Coastguard Worker 11234*da0073e9SAndroid Build Coastguard Worker def test_loop_unroll_negative(self): 11235*da0073e9SAndroid Build Coastguard Worker def fn(x): 11236*da0073e9SAndroid Build Coastguard Worker y = 0 11237*da0073e9SAndroid Build Coastguard Worker for _ in range(int(x)): 11238*da0073e9SAndroid Build Coastguard Worker y += 1 11239*da0073e9SAndroid Build Coastguard Worker return y 11240*da0073e9SAndroid Build Coastguard Worker 11241*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.tensor(-20),)) 11242*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.tensor(-2),)) 11243*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.tensor(-1),)) 11244*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.tensor(0),)) 11245*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.tensor(1),)) 11246*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.tensor(2),)) 11247*da0073e9SAndroid Build Coastguard Worker 11248*da0073e9SAndroid Build Coastguard Worker def test_where(self): 11249*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 11250*da0073e9SAndroid Build Coastguard Worker return torch.where(x > 0.0, x, y) 11251*da0073e9SAndroid Build Coastguard Worker 11252*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float))) 11253*da0073e9SAndroid Build Coastguard Worker 11254*da0073e9SAndroid Build Coastguard Worker def test_where_method(self): 11255*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 11256*da0073e9SAndroid Build Coastguard Worker return x.where(x > 0.0, y) 11257*da0073e9SAndroid Build Coastguard Worker 
11258*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float))) 11259*da0073e9SAndroid Build Coastguard Worker 11260*da0073e9SAndroid Build Coastguard Worker def test_union_to_number(self): 11261*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11262*da0073e9SAndroid Build Coastguard Worker def fn(x: Union[int, complex, float], y: Union[int, complex, float]): 11263*da0073e9SAndroid Build Coastguard Worker return x + y 11264*da0073e9SAndroid Build Coastguard Worker FileCheck().check(": Scalar):").run(fn.graph) 11265*da0073e9SAndroid Build Coastguard Worker 11266*da0073e9SAndroid Build Coastguard Worker def test_reassign_module_lhs(self): 11267*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\''): 11268*da0073e9SAndroid Build Coastguard Worker class ReassignSelfLHS(torch.jit.ScriptModule): 11269*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 11270*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 11271*da0073e9SAndroid Build Coastguard Worker for _ in range(20): 11272*da0073e9SAndroid Build Coastguard Worker self = x 11273*da0073e9SAndroid Build Coastguard Worker return self 11274*da0073e9SAndroid Build Coastguard Worker 11275*da0073e9SAndroid Build Coastguard Worker ReassignSelfLHS() 11276*da0073e9SAndroid Build Coastguard Worker 11277*da0073e9SAndroid Build Coastguard Worker def test_reassign_module_rhs(self): 11278*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module'): 11279*da0073e9SAndroid Build Coastguard Worker class ReassignSelfRHS(torch.jit.ScriptModule): 11280*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 11281*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 11282*da0073e9SAndroid Build Coastguard Worker for _ in range(20): 11283*da0073e9SAndroid Build Coastguard 
Worker x = self 11284*da0073e9SAndroid Build Coastguard Worker return self 11285*da0073e9SAndroid Build Coastguard Worker 11286*da0073e9SAndroid Build Coastguard Worker ReassignSelfRHS() 11287*da0073e9SAndroid Build Coastguard Worker 11288*da0073e9SAndroid Build Coastguard Worker def test_unknown_builtin(self): 11289*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'object has no attribute or method'): 11290*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11291*da0073e9SAndroid Build Coastguard Worker def unknown_builtin(x): 11292*da0073e9SAndroid Build Coastguard Worker return x.splork(3) 11293*da0073e9SAndroid Build Coastguard Worker 11294*da0073e9SAndroid Build Coastguard Worker def test_return_tuple(self): 11295*da0073e9SAndroid Build Coastguard Worker def return_tuple(x): 11296*da0073e9SAndroid Build Coastguard Worker a = (x, x) 11297*da0073e9SAndroid Build Coastguard Worker return a, x 11298*da0073e9SAndroid Build Coastguard Worker self.checkScript(return_tuple, (torch.rand(4),)) 11299*da0073e9SAndroid Build Coastguard Worker 11300*da0073e9SAndroid Build Coastguard Worker def test_add_tuple_optional(self): 11301*da0073e9SAndroid Build Coastguard Worker def foo(input: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]) -> Optional[torch.Tensor]: 11302*da0073e9SAndroid Build Coastguard Worker changed_input = input[0] + 1 11303*da0073e9SAndroid Build Coastguard Worker value: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (changed_input,) + input[1:] 11304*da0073e9SAndroid Build Coastguard Worker return value[2] 11305*da0073e9SAndroid Build Coastguard Worker inp: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] = (torch.rand(4), None, None) 11306*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (inp,)) 11307*da0073e9SAndroid Build Coastguard Worker 11308*da0073e9SAndroid Build Coastguard Worker def test_add_tuple_non_optional(self): 
11309*da0073e9SAndroid Build Coastguard Worker def foo(input: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor: 11310*da0073e9SAndroid Build Coastguard Worker changed_input = input[0] + 1 11311*da0073e9SAndroid Build Coastguard Worker value: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (changed_input,) + input[1:] 11312*da0073e9SAndroid Build Coastguard Worker return torch.sum(value[2]) + 4 11313*da0073e9SAndroid Build Coastguard Worker inp: Tuple[torch.Tensor, torch.Tensor, torch.Tensor] = (torch.rand(4), torch.rand(4), torch.rand(4)) 11314*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (inp,)) 11315*da0073e9SAndroid Build Coastguard Worker 11316*da0073e9SAndroid Build Coastguard Worker def test_add_tuple_different_types(self): 11317*da0073e9SAndroid Build Coastguard Worker def foo(a: Tuple[int, float], b: Tuple[int]) -> int: 11318*da0073e9SAndroid Build Coastguard Worker c: Tuple[int, float, int] = a + b 11319*da0073e9SAndroid Build Coastguard Worker d: Tuple[int, float, int, int] = c + b 11320*da0073e9SAndroid Build Coastguard Worker return d[3] + 1 11321*da0073e9SAndroid Build Coastguard Worker a = (1, 2.0) 11322*da0073e9SAndroid Build Coastguard Worker b = (3,) 11323*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (a, b)) 11324*da0073e9SAndroid Build Coastguard Worker 11325*da0073e9SAndroid Build Coastguard Worker def test_add_tuple_same_types(self): 11326*da0073e9SAndroid Build Coastguard Worker def foo(a: Tuple[int, int], b: Tuple[int, int, int]) -> int: 11327*da0073e9SAndroid Build Coastguard Worker c: Tuple[int, int, int, int, int] = a + b 11328*da0073e9SAndroid Build Coastguard Worker d: Tuple[int, int, int, int, int, int, int, int] = c + b 11329*da0073e9SAndroid Build Coastguard Worker return d[6] - 2 11330*da0073e9SAndroid Build Coastguard Worker a = (1, 2) 11331*da0073e9SAndroid Build Coastguard Worker b = (3, 4, 5) 11332*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (a, b)) 
11333*da0073e9SAndroid Build Coastguard Worker 11334*da0073e9SAndroid Build Coastguard Worker def test_method_no_self(self): 11335*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'): 11336*da0073e9SAndroid Build Coastguard Worker class MethodNoSelf(torch.jit.ScriptModule): 11337*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method # noqa: B902 11338*da0073e9SAndroid Build Coastguard Worker def forward(): # noqa: B902 11339*da0073e9SAndroid Build Coastguard Worker return torch.zeros(3, 4) 11340*da0073e9SAndroid Build Coastguard Worker 11341*da0073e9SAndroid Build Coastguard Worker MethodNoSelf() 11342*da0073e9SAndroid Build Coastguard Worker 11343*da0073e9SAndroid Build Coastguard Worker def test_return_stmt_not_at_end(self): 11344*da0073e9SAndroid Build Coastguard Worker def return_stmt(x): 11345*da0073e9SAndroid Build Coastguard Worker if bool(x > 3): 11346*da0073e9SAndroid Build Coastguard Worker return x + 3 11347*da0073e9SAndroid Build Coastguard Worker else: 11348*da0073e9SAndroid Build Coastguard Worker return x 11349*da0073e9SAndroid Build Coastguard Worker self.checkScript(return_stmt, (torch.rand(1),)) 11350*da0073e9SAndroid Build Coastguard Worker 11351*da0073e9SAndroid Build Coastguard Worker def test_for_in_range(self): 11352*da0073e9SAndroid Build Coastguard Worker def fn(): 11353*da0073e9SAndroid Build Coastguard Worker c = 0 11354*da0073e9SAndroid Build Coastguard Worker for i in range(100): 11355*da0073e9SAndroid Build Coastguard Worker c += i 11356*da0073e9SAndroid Build Coastguard Worker return c 11357*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 11358*da0073e9SAndroid Build Coastguard Worker 11359*da0073e9SAndroid Build Coastguard Worker def test_for_in_range_dynamic(self): 11360*da0073e9SAndroid Build Coastguard Worker def fn(): 11361*da0073e9SAndroid Build Coastguard Worker c = 0 11362*da0073e9SAndroid Build Coastguard Worker for i in 
range(100): 11363*da0073e9SAndroid Build Coastguard Worker acc = 0 11364*da0073e9SAndroid Build Coastguard Worker for j in range(i): 11365*da0073e9SAndroid Build Coastguard Worker acc += j 11366*da0073e9SAndroid Build Coastguard Worker c += acc 11367*da0073e9SAndroid Build Coastguard Worker return c 11368*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (), optimize=False) 11369*da0073e9SAndroid Build Coastguard Worker 11370*da0073e9SAndroid Build Coastguard Worker def test_for_in_range_ast(self): 11371*da0073e9SAndroid Build Coastguard Worker def test_script_for_in_range_ast(): 11372*da0073e9SAndroid Build Coastguard Worker c = 0 11373*da0073e9SAndroid Build Coastguard Worker for i in range(100): 11374*da0073e9SAndroid Build Coastguard Worker acc = 0 11375*da0073e9SAndroid Build Coastguard Worker for j in range(i): 11376*da0073e9SAndroid Build Coastguard Worker acc += j 11377*da0073e9SAndroid Build Coastguard Worker c += acc 11378*da0073e9SAndroid Build Coastguard Worker return c 11379*da0073e9SAndroid Build Coastguard Worker 11380*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_script_for_in_range_ast, ()) 11381*da0073e9SAndroid Build Coastguard Worker 11382*da0073e9SAndroid Build Coastguard Worker def test_for_in_range_if_ast(self): 11383*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11384*da0073e9SAndroid Build Coastguard Worker def test_script_for_in_range_if_ast(x): 11385*da0073e9SAndroid Build Coastguard Worker output = x 11386*da0073e9SAndroid Build Coastguard Worker for i in range(20): 11387*da0073e9SAndroid Build Coastguard Worker if i == 0: 11388*da0073e9SAndroid Build Coastguard Worker output = x.unsqueeze(0) 11389*da0073e9SAndroid Build Coastguard Worker else: 11390*da0073e9SAndroid Build Coastguard Worker output = torch.cat((output, x.unsqueeze(0)), dim=0) 11391*da0073e9SAndroid Build Coastguard Worker return output 11392*da0073e9SAndroid Build Coastguard Worker inputs = self._make_scalar_vars([0], torch.int64) 
11393*da0073e9SAndroid Build Coastguard Worker 11394*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20) 11395*da0073e9SAndroid Build Coastguard Worker 11396*da0073e9SAndroid Build Coastguard Worker def test_for_in_range_start_end(self): 11397*da0073e9SAndroid Build Coastguard Worker def fn(): 11398*da0073e9SAndroid Build Coastguard Worker x = 0 11399*da0073e9SAndroid Build Coastguard Worker for i in range(7, 100): 11400*da0073e9SAndroid Build Coastguard Worker x += i 11401*da0073e9SAndroid Build Coastguard Worker return x 11402*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 11403*da0073e9SAndroid Build Coastguard Worker 11404*da0073e9SAndroid Build Coastguard Worker def test_for_in_range_start_end_step(self): 11405*da0073e9SAndroid Build Coastguard Worker def fn(start, end, step): 11406*da0073e9SAndroid Build Coastguard Worker # type: (int, int, int) -> int 11407*da0073e9SAndroid Build Coastguard Worker x = 0 11408*da0073e9SAndroid Build Coastguard Worker for i in range(start, end, step): 11409*da0073e9SAndroid Build Coastguard Worker x += i 11410*da0073e9SAndroid Build Coastguard Worker return x 11411*da0073e9SAndroid Build Coastguard Worker 11412*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (7, 100, 7)) 11413*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (7, 100, -7)) 11414*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (2, -11, -3)) 11415*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (2, -11, 3)) 11416*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (2, 10, 3)) 11417*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (-2, -10, -10)) 11418*da0073e9SAndroid Build Coastguard Worker 11419*da0073e9SAndroid Build Coastguard Worker def test_for_in_range_zero_step(self): 11420*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11421*da0073e9SAndroid Build Coastguard Worker def fn(): 
11422*da0073e9SAndroid Build Coastguard Worker x = 0 11423*da0073e9SAndroid Build Coastguard Worker for i in range(2, -11, 0): 11424*da0073e9SAndroid Build Coastguard Worker x += i 11425*da0073e9SAndroid Build Coastguard Worker return x 11426*da0073e9SAndroid Build Coastguard Worker 11427*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "must not be zero"): 11428*da0073e9SAndroid Build Coastguard Worker fn() 11429*da0073e9SAndroid Build Coastguard Worker 11430*da0073e9SAndroid Build Coastguard Worker def test_range_args(self): 11431*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'range expected at least 1 arguments, got 0'): 11432*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11433*da0073e9SAndroid Build Coastguard Worker def range_no_arg(x): 11434*da0073e9SAndroid Build Coastguard Worker for _ in range(): 11435*da0073e9SAndroid Build Coastguard Worker x += 1 11436*da0073e9SAndroid Build Coastguard Worker return x 11437*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'found float'): 11438*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11439*da0073e9SAndroid Build Coastguard Worker def range_non_float(): 11440*da0073e9SAndroid Build Coastguard Worker for i in range(.5): 11441*da0073e9SAndroid Build Coastguard Worker print(i) 11442*da0073e9SAndroid Build Coastguard Worker 11443*da0073e9SAndroid Build Coastguard Worker def test_parse_empty_tuple_annotation(self): 11444*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 11445*da0073e9SAndroid Build Coastguard Worker def foo(x : Tuple[()]) -> Tuple[()]: 11446*da0073e9SAndroid Build Coastguard Worker return x 11447*da0073e9SAndroid Build Coastguard Worker ''') 11448*da0073e9SAndroid Build Coastguard Worker 11449*da0073e9SAndroid Build Coastguard Worker foo_code = cu.find_function('foo').code 11450*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check("Tuple[()]").check("Tuple[()]").run(foo_code) 11451*da0073e9SAndroid Build Coastguard Worker 11452*da0073e9SAndroid Build Coastguard Worker def test_parse_empty_tuple_annotation_element_error(self): 11453*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 11454*da0073e9SAndroid Build Coastguard Worker RuntimeError, 'Tuple literal in Tuple type annotation must not have any elements'): 11455*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 11456*da0073e9SAndroid Build Coastguard Worker def foo(x : Tuple[(int,)]) -> Tuple[(int,)]: 11457*da0073e9SAndroid Build Coastguard Worker return x 11458*da0073e9SAndroid Build Coastguard Worker ''') 11459*da0073e9SAndroid Build Coastguard Worker 11460*da0073e9SAndroid Build Coastguard Worker def test_parse_none_type_annotation(self): 11461*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 11462*da0073e9SAndroid Build Coastguard Worker def foo(x : NoneType) -> NoneType: 11463*da0073e9SAndroid Build Coastguard Worker return x 11464*da0073e9SAndroid Build Coastguard Worker ''') 11465*da0073e9SAndroid Build Coastguard Worker 11466*da0073e9SAndroid Build Coastguard Worker foo_code = cu.find_function('foo').code 11467*da0073e9SAndroid Build Coastguard Worker FileCheck().check(": NoneType").check("-> NoneType").run(foo_code) 11468*da0073e9SAndroid Build Coastguard Worker 11469*da0073e9SAndroid Build Coastguard Worker def test_empty_tuple_str(self): 11470*da0073e9SAndroid Build Coastguard Worker empty_tuple_type = torch._C.TupleType([]) 11471*da0073e9SAndroid Build Coastguard Worker g = {'Tuple' : typing.Tuple} 11472*da0073e9SAndroid Build Coastguard Worker python_type = eval(empty_tuple_type.annotation_str, g) 11473*da0073e9SAndroid Build Coastguard Worker assert python_type is typing.Tuple[()] 11474*da0073e9SAndroid Build Coastguard Worker 11475*da0073e9SAndroid Build Coastguard Worker def test_tuple_str(self): 11476*da0073e9SAndroid Build 
Coastguard Worker tuple1_type = torch._C.TupleType([torch._C.StringType.get()]) 11477*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple1_type.annotation_str, "Tuple[str]") 11478*da0073e9SAndroid Build Coastguard Worker tuple2_type = torch._C.TupleType([torch._C.StringType.get(), torch._C.StringType.get()]) 11479*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tuple2_type.annotation_str, "Tuple[str, str]") 11480*da0073e9SAndroid Build Coastguard Worker 11481*da0073e9SAndroid Build Coastguard Worker def test_dict_str(self): 11482*da0073e9SAndroid Build Coastguard Worker dict_type = torch._C.DictType(torch._C.StringType.get(), torch._C.StringType.get()) 11483*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dict_type.annotation_str, "Dict[str, str]") 11484*da0073e9SAndroid Build Coastguard Worker 11485*da0073e9SAndroid Build Coastguard Worker def test_none_type_str(self): 11486*da0073e9SAndroid Build Coastguard Worker none_type = torch._C.NoneType.get() 11487*da0073e9SAndroid Build Coastguard Worker g = {'NoneType' : type(None)} 11488*da0073e9SAndroid Build Coastguard Worker python_type = eval(none_type.annotation_str, g) 11489*da0073e9SAndroid Build Coastguard Worker assert python_type is type(None) 11490*da0073e9SAndroid Build Coastguard Worker 11491*da0073e9SAndroid Build Coastguard Worker @skipIfTorchDynamo("TorchDynamo fails with unknown reason") 11492*da0073e9SAndroid Build Coastguard Worker def test_zip_enumerate_modulelist(self): 11493*da0073e9SAndroid Build Coastguard Worker class Sub(torch.nn.Module): 11494*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 11495*da0073e9SAndroid Build Coastguard Worker return thing - 2 11496*da0073e9SAndroid Build Coastguard Worker 11497*da0073e9SAndroid Build Coastguard Worker class Double(torch.nn.Module): 11498*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 11499*da0073e9SAndroid Build Coastguard Worker return thing * 2 11500*da0073e9SAndroid Build 
Coastguard Worker 11501*da0073e9SAndroid Build Coastguard Worker # zipping over two 11502*da0073e9SAndroid Build Coastguard Worker class ZipModLists(torch.nn.Module): 11503*da0073e9SAndroid Build Coastguard Worker def __init__(self, mods, mods2): 11504*da0073e9SAndroid Build Coastguard Worker super().__init__() 11505*da0073e9SAndroid Build Coastguard Worker self.mods = mods 11506*da0073e9SAndroid Build Coastguard Worker self.mods2 = mods2 11507*da0073e9SAndroid Build Coastguard Worker 11508*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 11509*da0073e9SAndroid Build Coastguard Worker iter = 0 11510*da0073e9SAndroid Build Coastguard Worker for mod1, mod2 in zip(self.mods, self.mods2): 11511*da0073e9SAndroid Build Coastguard Worker x = mod2(mod1(x)) 11512*da0073e9SAndroid Build Coastguard Worker iter += 1 11513*da0073e9SAndroid Build Coastguard Worker return x, iter 11514*da0073e9SAndroid Build Coastguard Worker 11515*da0073e9SAndroid Build Coastguard Worker class ZipWithValues(torch.nn.Module): 11516*da0073e9SAndroid Build Coastguard Worker __constants__ = ['tup_larger', 'tup_smaller'] 11517*da0073e9SAndroid Build Coastguard Worker 11518*da0073e9SAndroid Build Coastguard Worker def __init__(self, mods, mods2): 11519*da0073e9SAndroid Build Coastguard Worker super().__init__() 11520*da0073e9SAndroid Build Coastguard Worker self.mods = mods 11521*da0073e9SAndroid Build Coastguard Worker self.mods2 = mods2 11522*da0073e9SAndroid Build Coastguard Worker self.tup_larger = list(range(len(mods2) + 1)) 11523*da0073e9SAndroid Build Coastguard Worker self.tup_smaller = list(range(max(len(mods2) + 1, 1))) 11524*da0073e9SAndroid Build Coastguard Worker 11525*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 11526*da0073e9SAndroid Build Coastguard Worker iter = 0 11527*da0073e9SAndroid Build Coastguard Worker x2 = x 11528*da0073e9SAndroid Build Coastguard Worker for val, mod1, mod2 in zip(self.tup_larger, self.mods, self.mods2): 11529*da0073e9SAndroid 
Build Coastguard Worker x = mod2(mod1(x)) + val 11530*da0073e9SAndroid Build Coastguard Worker iter += 1 11531*da0073e9SAndroid Build Coastguard Worker for val, mod1, mod2 in zip(self.tup_smaller, self.mods, self.mods2): 11532*da0073e9SAndroid Build Coastguard Worker x2 = mod2(mod1(x2)) + val 11533*da0073e9SAndroid Build Coastguard Worker iter += 1 11534*da0073e9SAndroid Build Coastguard Worker return x, iter 11535*da0073e9SAndroid Build Coastguard Worker 11536*da0073e9SAndroid Build Coastguard Worker mods = nn.ModuleList([Double()]), nn.ModuleList([Double(), Sub(), Sub()]), nn.ModuleList([Sub(), Double()]) 11537*da0073e9SAndroid Build Coastguard Worker for i in range(len(mods)): 11538*da0073e9SAndroid Build Coastguard Worker for j in range(len(mods)): 11539*da0073e9SAndroid Build Coastguard Worker mod = ZipModLists(mods[i], mods[j]) 11540*da0073e9SAndroid Build Coastguard Worker self.checkModule(mod, (torch.tensor(.5),)) 11541*da0073e9SAndroid Build Coastguard Worker mod2 = ZipWithValues(mods[i], mods[j]) 11542*da0073e9SAndroid Build Coastguard Worker self.checkModule(mod2, (torch.tensor(.5),)) 11543*da0073e9SAndroid Build Coastguard Worker 11544*da0073e9SAndroid Build Coastguard Worker 11545*da0073e9SAndroid Build Coastguard Worker def test_enumerate_modlist_range(self): 11546*da0073e9SAndroid Build Coastguard Worker class Double(torch.nn.Module): 11547*da0073e9SAndroid Build Coastguard Worker def forward(self, thing): 11548*da0073e9SAndroid Build Coastguard Worker return thing * 2 11549*da0073e9SAndroid Build Coastguard Worker 11550*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 11551*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 11552*da0073e9SAndroid Build Coastguard Worker super().__init__() 11553*da0073e9SAndroid Build Coastguard Worker self.mods = nn.ModuleList([Double(), Double()]) 11554*da0073e9SAndroid Build Coastguard Worker 11555*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 
11556*da0073e9SAndroid Build Coastguard Worker x2 = x 11557*da0073e9SAndroid Build Coastguard Worker iter = 0 11558*da0073e9SAndroid Build Coastguard Worker for val, mod in enumerate(self.mods): 11559*da0073e9SAndroid Build Coastguard Worker x2 = mod(x2) * val 11560*da0073e9SAndroid Build Coastguard Worker iter += 1 11561*da0073e9SAndroid Build Coastguard Worker return iter, x, x2 11562*da0073e9SAndroid Build Coastguard Worker 11563*da0073e9SAndroid Build Coastguard Worker self.checkModule(Mod(), (torch.tensor(.5),)) 11564*da0073e9SAndroid Build Coastguard Worker 11565*da0073e9SAndroid Build Coastguard Worker # variable length, modulelist 11566*da0073e9SAndroid Build Coastguard Worker class Mod2(Mod): 11567*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 11568*da0073e9SAndroid Build Coastguard Worker for val, mod in zip(range(int(x)), self.mods): 11569*da0073e9SAndroid Build Coastguard Worker x = mod(x) * val 11570*da0073e9SAndroid Build Coastguard Worker return x 11571*da0073e9SAndroid Build Coastguard Worker 11572*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"): 11573*da0073e9SAndroid Build Coastguard Worker torch.jit.script(Mod2()) 11574*da0073e9SAndroid Build Coastguard Worker 11575*da0073e9SAndroid Build Coastguard Worker # modulelist, variable length 11576*da0073e9SAndroid Build Coastguard Worker class Mod3(Mod): 11577*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 11578*da0073e9SAndroid Build Coastguard Worker for val, mod in zip(self.mods, range(int(x))): 11579*da0073e9SAndroid Build Coastguard Worker x = mod(x) * val 11580*da0073e9SAndroid Build Coastguard Worker return x 11581*da0073e9SAndroid Build Coastguard Worker 11582*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "that does not have a statically determinable length"): 11583*da0073e9SAndroid Build Coastguard Worker torch.jit.script(Mod3()) 
11584*da0073e9SAndroid Build Coastguard Worker 11585*da0073e9SAndroid Build Coastguard Worker def test_for_in_enumerate(self): 11586*da0073e9SAndroid Build Coastguard Worker def fn(x): 11587*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> int 11588*da0073e9SAndroid Build Coastguard Worker sum = 0 11589*da0073e9SAndroid Build Coastguard Worker for (i, v) in enumerate(x): 11590*da0073e9SAndroid Build Coastguard Worker sum += i * v 11591*da0073e9SAndroid Build Coastguard Worker 11592*da0073e9SAndroid Build Coastguard Worker return sum 11593*da0073e9SAndroid Build Coastguard Worker 11594*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([1, 2, 3, 4, 5],)) 11595*da0073e9SAndroid Build Coastguard Worker 11596*da0073e9SAndroid Build Coastguard Worker def fn_enumerate_start_arg(x): 11597*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> int 11598*da0073e9SAndroid Build Coastguard Worker sum = 0 11599*da0073e9SAndroid Build Coastguard Worker for (i, v) in enumerate(x, 1): 11600*da0073e9SAndroid Build Coastguard Worker sum += i * v 11601*da0073e9SAndroid Build Coastguard Worker 11602*da0073e9SAndroid Build Coastguard Worker return sum 11603*da0073e9SAndroid Build Coastguard Worker 11604*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_enumerate_start_arg, ([1, 2, 3, 4, 5],)) 11605*da0073e9SAndroid Build Coastguard Worker 11606*da0073e9SAndroid Build Coastguard Worker def fn_enumerate_start_kwarg(x): 11607*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> int 11608*da0073e9SAndroid Build Coastguard Worker sum = 0 11609*da0073e9SAndroid Build Coastguard Worker for (i, v) in enumerate(x, start=1): 11610*da0073e9SAndroid Build Coastguard Worker sum += i * v 11611*da0073e9SAndroid Build Coastguard Worker 11612*da0073e9SAndroid Build Coastguard Worker return sum 11613*da0073e9SAndroid Build Coastguard Worker 11614*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_enumerate_start_kwarg, ([1, 2, 3, 4, 
5],)) 11615*da0073e9SAndroid Build Coastguard Worker 11616*da0073e9SAndroid Build Coastguard Worker def fn_nested_enumerate(x): 11617*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> int 11618*da0073e9SAndroid Build Coastguard Worker sum = 0 11619*da0073e9SAndroid Build Coastguard Worker for (i, (j, v)) in enumerate(enumerate(x)): 11620*da0073e9SAndroid Build Coastguard Worker sum += i * j * v 11621*da0073e9SAndroid Build Coastguard Worker 11622*da0073e9SAndroid Build Coastguard Worker return sum 11623*da0073e9SAndroid Build Coastguard Worker 11624*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_nested_enumerate, ([1, 2, 3, 4, 5],)) 11625*da0073e9SAndroid Build Coastguard Worker 11626*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'enumerate expected at least 1 arguments, got 0'): 11627*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11628*da0073e9SAndroid Build Coastguard Worker def enumerate_no_arg(x): 11629*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> int 11630*da0073e9SAndroid Build Coastguard Worker sum = 0 11631*da0073e9SAndroid Build Coastguard Worker for _ in enumerate(): 11632*da0073e9SAndroid Build Coastguard Worker sum += 1 11633*da0073e9SAndroid Build Coastguard Worker 11634*da0073e9SAndroid Build Coastguard Worker return sum 11635*da0073e9SAndroid Build Coastguard Worker 11636*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'enumerate expected at most 2 arguments, got 3'): 11637*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11638*da0073e9SAndroid Build Coastguard Worker def enumerate_too_many_args(x): 11639*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> int 11640*da0073e9SAndroid Build Coastguard Worker sum = 0 11641*da0073e9SAndroid Build Coastguard Worker for _ in enumerate(x, x, x): 11642*da0073e9SAndroid Build Coastguard Worker sum += 1 11643*da0073e9SAndroid Build Coastguard Worker 
11644*da0073e9SAndroid Build Coastguard Worker return sum 11645*da0073e9SAndroid Build Coastguard Worker 11646*da0073e9SAndroid Build Coastguard Worker def test_list_comprehension_modulelist(self): 11647*da0073e9SAndroid Build Coastguard Worker class Inner(torch.nn.Module): 11648*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 11649*da0073e9SAndroid Build Coastguard Worker return x + 10 11650*da0073e9SAndroid Build Coastguard Worker 11651*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 11652*da0073e9SAndroid Build Coastguard Worker def __init__(self, mod_list): 11653*da0073e9SAndroid Build Coastguard Worker super().__init__() 11654*da0073e9SAndroid Build Coastguard Worker self.module_list = mod_list 11655*da0073e9SAndroid Build Coastguard Worker 11656*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 11657*da0073e9SAndroid Build Coastguard Worker out = torch.jit.annotate(List[Tensor], [mod(x) for mod in self.module_list]) 11658*da0073e9SAndroid Build Coastguard Worker return out 11659*da0073e9SAndroid Build Coastguard Worker 11660*da0073e9SAndroid Build Coastguard Worker mod = M(nn.ModuleList([Inner(), Inner()])) 11661*da0073e9SAndroid Build Coastguard Worker self.checkModule(mod, (torch.tensor(3),)) 11662*da0073e9SAndroid Build Coastguard Worker 11663*da0073e9SAndroid Build Coastguard Worker mod = M(nn.ModuleList([])) 11664*da0073e9SAndroid Build Coastguard Worker torch.jit.script(mod) 11665*da0073e9SAndroid Build Coastguard Worker 11666*da0073e9SAndroid Build Coastguard Worker class M2(M): 11667*da0073e9SAndroid Build Coastguard Worker def __init__(self, mod_list): 11668*da0073e9SAndroid Build Coastguard Worker super().__init__(mod_list) 11669*da0073e9SAndroid Build Coastguard Worker 11670*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 11671*da0073e9SAndroid Build Coastguard Worker out = [mod(x) for mod in self.module_list] 11672*da0073e9SAndroid Build Coastguard Worker return out 
11673*da0073e9SAndroid Build Coastguard Worker 11674*da0073e9SAndroid Build Coastguard Worker mod = M2(nn.ModuleList([Inner(), Inner()])) 11675*da0073e9SAndroid Build Coastguard Worker self.checkModule(mod, (torch.tensor(3),)) 11676*da0073e9SAndroid Build Coastguard Worker 11677*da0073e9SAndroid Build Coastguard Worker mod = M2(nn.ModuleList([])) 11678*da0073e9SAndroid Build Coastguard Worker # defaults to List of Tensor for empty modulelist 11679*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.jit.script(mod)(torch.tensor(.5)), []) 11680*da0073e9SAndroid Build Coastguard Worker 11681*da0073e9SAndroid Build Coastguard Worker def bad_type_annotation(): 11682*da0073e9SAndroid Build Coastguard Worker out = torch.jit.annotate(int, [x for x in [1, 2, 3]]) # noqa: C416 11683*da0073e9SAndroid Build Coastguard Worker return out 11684*da0073e9SAndroid Build Coastguard Worker 11685*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "Expected an annotation" 11686*da0073e9SAndroid Build Coastguard Worker " of type List"): 11687*da0073e9SAndroid Build Coastguard Worker torch.jit.script(bad_type_annotation) 11688*da0073e9SAndroid Build Coastguard Worker 11689*da0073e9SAndroid Build Coastguard Worker def test_list_comprehension_variable_write(self): 11690*da0073e9SAndroid Build Coastguard Worker # i in comprehension doesn't write to function scope 11691*da0073e9SAndroid Build Coastguard Worker def foo(): 11692*da0073e9SAndroid Build Coastguard Worker i = 1 11693*da0073e9SAndroid Build Coastguard Worker x = [i if i != 5 else 3 for i in range(7)] # noqa: C416 11694*da0073e9SAndroid Build Coastguard Worker return i, x 11695*da0073e9SAndroid Build Coastguard Worker 11696*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), torch.jit.script(foo)()) 11697*da0073e9SAndroid Build Coastguard Worker 11698*da0073e9SAndroid Build Coastguard Worker def test_for_in_zip(self): 11699*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 
11700*da0073e9SAndroid Build Coastguard Worker # type: (List[int], List[int]) -> int 11701*da0073e9SAndroid Build Coastguard Worker sum = 0 11702*da0073e9SAndroid Build Coastguard Worker for (i, j) in zip(x, y): 11703*da0073e9SAndroid Build Coastguard Worker sum += i * j 11704*da0073e9SAndroid Build Coastguard Worker 11705*da0073e9SAndroid Build Coastguard Worker return sum 11706*da0073e9SAndroid Build Coastguard Worker 11707*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([1, 2, 3, 4, 5], [2, 3, 4, 5, 6])) 11708*da0073e9SAndroid Build Coastguard Worker 11709*da0073e9SAndroid Build Coastguard Worker def fn_multi_inputs(x, y, z): 11710*da0073e9SAndroid Build Coastguard Worker # type: (List[int], List[int], List[int]) -> int 11711*da0073e9SAndroid Build Coastguard Worker sum = 0 11712*da0073e9SAndroid Build Coastguard Worker for (i, j, k) in zip(x, y, z): 11713*da0073e9SAndroid Build Coastguard Worker sum += i * j * k 11714*da0073e9SAndroid Build Coastguard Worker 11715*da0073e9SAndroid Build Coastguard Worker return sum 11716*da0073e9SAndroid Build Coastguard Worker 11717*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_multi_inputs, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6])) 11718*da0073e9SAndroid Build Coastguard Worker 11719*da0073e9SAndroid Build Coastguard Worker def fn_nested_zip(x, y, z): 11720*da0073e9SAndroid Build Coastguard Worker # type: (List[int], List[int], List[int]) -> int 11721*da0073e9SAndroid Build Coastguard Worker sum = 0 11722*da0073e9SAndroid Build Coastguard Worker for (i, (j, k)) in zip(x, zip(y, z)): 11723*da0073e9SAndroid Build Coastguard Worker sum += i * j * k 11724*da0073e9SAndroid Build Coastguard Worker 11725*da0073e9SAndroid Build Coastguard Worker return sum 11726*da0073e9SAndroid Build Coastguard Worker 11727*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_multi_inputs, ([1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6])) 11728*da0073e9SAndroid Build Coastguard Worker 11729*da0073e9SAndroid Build 
Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'zip expected at least 1 arguments, got 0'): 11730*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11731*da0073e9SAndroid Build Coastguard Worker def zip_no_arg(x): 11732*da0073e9SAndroid Build Coastguard Worker # type: (List[int]) -> int 11733*da0073e9SAndroid Build Coastguard Worker sum = 0 11734*da0073e9SAndroid Build Coastguard Worker for _ in zip(): 11735*da0073e9SAndroid Build Coastguard Worker sum += 1 11736*da0073e9SAndroid Build Coastguard Worker 11737*da0073e9SAndroid Build Coastguard Worker return sum 11738*da0073e9SAndroid Build Coastguard Worker 11739*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r'too many values to unpack: need 2 but found 3'): 11740*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11741*da0073e9SAndroid Build Coastguard Worker def fn_nested_zip_wrong_target_assign(x, y, z): 11742*da0073e9SAndroid Build Coastguard Worker # type: (List[int], List[int], List[int]) -> int 11743*da0073e9SAndroid Build Coastguard Worker sum = 0 11744*da0073e9SAndroid Build Coastguard Worker for (i, (j, k)) in zip(x, y, z): 11745*da0073e9SAndroid Build Coastguard Worker sum += i * j * k 11746*da0073e9SAndroid Build Coastguard Worker 11747*da0073e9SAndroid Build Coastguard Worker return sum 11748*da0073e9SAndroid Build Coastguard Worker 11749*da0073e9SAndroid Build Coastguard Worker def test_for_in_zip_enumerate(self): 11750*da0073e9SAndroid Build Coastguard Worker def fn_zip_enumerate(x, y): 11751*da0073e9SAndroid Build Coastguard Worker # type: (List[int], List[int]) -> int 11752*da0073e9SAndroid Build Coastguard Worker sum = 0 11753*da0073e9SAndroid Build Coastguard Worker for (i, (j, v), k) in zip(x, enumerate(y), range(0, 100)): 11754*da0073e9SAndroid Build Coastguard Worker sum += i * j * v * k 11755*da0073e9SAndroid Build Coastguard Worker 11756*da0073e9SAndroid Build Coastguard Worker return sum 11757*da0073e9SAndroid Build 
Coastguard Worker 11758*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_zip_enumerate, ([1, 2, 3, 4], [2, 3, 4, 5])) 11759*da0073e9SAndroid Build Coastguard Worker 11760*da0073e9SAndroid Build Coastguard Worker def fn_enumerate_zip(x, y): 11761*da0073e9SAndroid Build Coastguard Worker # type: (List[int], List[int]) -> int 11762*da0073e9SAndroid Build Coastguard Worker sum = 0 11763*da0073e9SAndroid Build Coastguard Worker for (i, (j, v)) in enumerate(zip(x, y)): 11764*da0073e9SAndroid Build Coastguard Worker sum += i * j * v 11765*da0073e9SAndroid Build Coastguard Worker 11766*da0073e9SAndroid Build Coastguard Worker return sum 11767*da0073e9SAndroid Build Coastguard Worker 11768*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn_enumerate_zip, ([1, 2, 3, 4], [2, 3, 4, 5])) 11769*da0073e9SAndroid Build Coastguard Worker 11770*da0073e9SAndroid Build Coastguard Worker def test_for_in_tensors(self): 11771*da0073e9SAndroid Build Coastguard Worker def test_sizes(x): 11772*da0073e9SAndroid Build Coastguard Worker sumz = 0 11773*da0073e9SAndroid Build Coastguard Worker for s in x: 11774*da0073e9SAndroid Build Coastguard Worker sumz += 1 11775*da0073e9SAndroid Build Coastguard Worker return sumz 11776*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),)) 11777*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_sizes, (torch.rand(777),)) 11778*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_sizes, (torch.rand(0),)) 11779*da0073e9SAndroid Build Coastguard Worker 11780*da0073e9SAndroid Build Coastguard Worker def test_for_in_tensors_rank0(self): 11781*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "of a 0-d tensor"): 11782*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11783*da0073e9SAndroid Build Coastguard Worker def test_sizes(x): 11784*da0073e9SAndroid Build Coastguard Worker sumz = 0 11785*da0073e9SAndroid Build Coastguard Worker 
for s in x: 11786*da0073e9SAndroid Build Coastguard Worker sumz += 1 11787*da0073e9SAndroid Build Coastguard Worker return sumz 11788*da0073e9SAndroid Build Coastguard Worker 11789*da0073e9SAndroid Build Coastguard Worker test_sizes(torch.tensor(1)) 11790*da0073e9SAndroid Build Coastguard Worker 11791*da0073e9SAndroid Build Coastguard Worker def test_for_in_tensors_fail_scalar(self): 11792*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "'float' object is not iterable"): 11793*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11794*da0073e9SAndroid Build Coastguard Worker def test_sizes(x): 11795*da0073e9SAndroid Build Coastguard Worker # type: (float) -> int 11796*da0073e9SAndroid Build Coastguard Worker sumz = 0 11797*da0073e9SAndroid Build Coastguard Worker for s in x: 11798*da0073e9SAndroid Build Coastguard Worker sumz += 1 11799*da0073e9SAndroid Build Coastguard Worker return sumz 11800*da0073e9SAndroid Build Coastguard Worker 11801*da0073e9SAndroid Build Coastguard Worker test_sizes(0.0) 11802*da0073e9SAndroid Build Coastguard Worker 11803*da0073e9SAndroid Build Coastguard Worker def test_for_in_tensors_nested(self): 11804*da0073e9SAndroid Build Coastguard Worker def test_sizes(x): 11805*da0073e9SAndroid Build Coastguard Worker sumz = 0 11806*da0073e9SAndroid Build Coastguard Worker for n in x: 11807*da0073e9SAndroid Build Coastguard Worker for t in n: 11808*da0073e9SAndroid Build Coastguard Worker sumz += 1 11809*da0073e9SAndroid Build Coastguard Worker return sumz 11810*da0073e9SAndroid Build Coastguard Worker 11811*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_sizes, (torch.rand(5, 4, 3, 2, 1),)) 11812*da0073e9SAndroid Build Coastguard Worker 11813*da0073e9SAndroid Build Coastguard Worker # to avoid defining sum_list in multiple tests 11814*da0073e9SAndroid Build Coastguard Worker def get_sum_list_fn(self): 11815*da0073e9SAndroid Build Coastguard Worker def sum_list(a): 11816*da0073e9SAndroid 
Build Coastguard Worker # type: (List[int]) -> int 11817*da0073e9SAndroid Build Coastguard Worker sum = 0 11818*da0073e9SAndroid Build Coastguard Worker for i in a: 11819*da0073e9SAndroid Build Coastguard Worker sum += i 11820*da0073e9SAndroid Build Coastguard Worker 11821*da0073e9SAndroid Build Coastguard Worker return sum 11822*da0073e9SAndroid Build Coastguard Worker 11823*da0073e9SAndroid Build Coastguard Worker return sum_list 11824*da0073e9SAndroid Build Coastguard Worker 11825*da0073e9SAndroid Build Coastguard Worker def test_sum_list_diff_elms(self): 11826*da0073e9SAndroid Build Coastguard Worker self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],)) 11827*da0073e9SAndroid Build Coastguard Worker 11828*da0073e9SAndroid Build Coastguard Worker def test_sum_list_empty(self): 11829*da0073e9SAndroid Build Coastguard Worker self.checkScript(self.get_sum_list_fn(), ([],)) 11830*da0073e9SAndroid Build Coastguard Worker 11831*da0073e9SAndroid Build Coastguard Worker def test_sum_list_one(self): 11832*da0073e9SAndroid Build Coastguard Worker self.checkScript(self.get_sum_list_fn(), ([1],)) 11833*da0073e9SAndroid Build Coastguard Worker 11834*da0073e9SAndroid Build Coastguard Worker def test_sum_list_literal(self): 11835*da0073e9SAndroid Build Coastguard Worker 11836*da0073e9SAndroid Build Coastguard Worker def sum_list(): 11837*da0073e9SAndroid Build Coastguard Worker # type: () -> int 11838*da0073e9SAndroid Build Coastguard Worker sum = 0 11839*da0073e9SAndroid Build Coastguard Worker for i in [1, 2, 3, 4, 5]: 11840*da0073e9SAndroid Build Coastguard Worker sum += i 11841*da0073e9SAndroid Build Coastguard Worker 11842*da0073e9SAndroid Build Coastguard Worker return sum 11843*da0073e9SAndroid Build Coastguard Worker 11844*da0073e9SAndroid Build Coastguard Worker self.checkScript(sum_list, ()) 11845*da0073e9SAndroid Build Coastguard Worker 11846*da0073e9SAndroid Build Coastguard Worker def test_sum_list_wrong_type(self): 11847*da0073e9SAndroid Build Coastguard 
Worker 11848*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"): 11849*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 11850*da0073e9SAndroid Build Coastguard Worker def sum_list(a): 11851*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 11852*da0073e9SAndroid Build Coastguard Worker sum = 0 11853*da0073e9SAndroid Build Coastguard Worker for i in a: # noqa: T484 11854*da0073e9SAndroid Build Coastguard Worker sum += i 11855*da0073e9SAndroid Build Coastguard Worker 11856*da0073e9SAndroid Build Coastguard Worker return sum 11857*da0073e9SAndroid Build Coastguard Worker 11858*da0073e9SAndroid Build Coastguard Worker sum_list(1) 11859*da0073e9SAndroid Build Coastguard Worker 11860*da0073e9SAndroid Build Coastguard Worker def test_list_iterables(self): 11861*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'): 11862*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 11863*da0073e9SAndroid Build Coastguard Worker def list_iterables(x): 11864*da0073e9SAndroid Build Coastguard Worker for i, j in [2, 3, 4], [5, 6, 7]: 11865*da0073e9SAndroid Build Coastguard Worker x += i 11866*da0073e9SAndroid Build Coastguard Worker x += j 11867*da0073e9SAndroid Build Coastguard Worker return x 11868*da0073e9SAndroid Build Coastguard Worker ''') 11869*da0073e9SAndroid Build Coastguard Worker 11870*da0073e9SAndroid Build Coastguard Worker def test_for_in_string(self): 11871*da0073e9SAndroid Build Coastguard Worker def test_strings(x): 11872*da0073e9SAndroid Build Coastguard Worker # type: (str) -> str 11873*da0073e9SAndroid Build Coastguard Worker reverse = "" 11874*da0073e9SAndroid Build Coastguard Worker for c in x: 11875*da0073e9SAndroid Build Coastguard Worker reverse = c + reverse 11876*da0073e9SAndroid Build Coastguard Worker return reverse 11877*da0073e9SAndroid Build Coastguard Worker 
11878*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_strings, ("hello",)) 11879*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_strings, ("",)) 11880*da0073e9SAndroid Build Coastguard Worker 11881*da0073e9SAndroid Build Coastguard Worker def test_list_strings(x): 11882*da0073e9SAndroid Build Coastguard Worker # type: (List[str]) -> str 11883*da0073e9SAndroid Build Coastguard Worker result = "" 11884*da0073e9SAndroid Build Coastguard Worker for sub_str in x: 11885*da0073e9SAndroid Build Coastguard Worker result += sub_str 11886*da0073e9SAndroid Build Coastguard Worker return result 11887*da0073e9SAndroid Build Coastguard Worker 11888*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_list_strings, (["hello", "world"],)) 11889*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_list_strings, (["hello", " ", "world", ""],)) 11890*da0073e9SAndroid Build Coastguard Worker 11891*da0073e9SAndroid Build Coastguard Worker def test_for_in_dict(self): 11892*da0073e9SAndroid Build Coastguard Worker def test_dicts(x): 11893*da0073e9SAndroid Build Coastguard Worker # type: (Dict[str, int]) -> int 11894*da0073e9SAndroid Build Coastguard Worker sum = 0 11895*da0073e9SAndroid Build Coastguard Worker for key in x: 11896*da0073e9SAndroid Build Coastguard Worker sum += x[key] 11897*da0073e9SAndroid Build Coastguard Worker return sum 11898*da0073e9SAndroid Build Coastguard Worker 11899*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) 11900*da0073e9SAndroid Build Coastguard Worker 11901*da0073e9SAndroid Build Coastguard Worker def test_dict_keys_values(x): 11902*da0073e9SAndroid Build Coastguard Worker # type: (Dict[str, int]) -> Tuple[str, int] 11903*da0073e9SAndroid Build Coastguard Worker key_str = "" 11904*da0073e9SAndroid Build Coastguard Worker sum = 0 11905*da0073e9SAndroid Build Coastguard Worker for key in x.keys(): 11906*da0073e9SAndroid Build Coastguard Worker key_str += 
key 11907*da0073e9SAndroid Build Coastguard Worker for val in x.values(): 11908*da0073e9SAndroid Build Coastguard Worker sum += val 11909*da0073e9SAndroid Build Coastguard Worker return key_str, sum 11910*da0073e9SAndroid Build Coastguard Worker 11911*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_dicts, ({"a": 1, "b": 2, "c": 3},)) 11912*da0073e9SAndroid Build Coastguard Worker 11913*da0073e9SAndroid Build Coastguard Worker def test_for_tuple_unpack(self): 11914*da0073e9SAndroid Build Coastguard Worker def for_tuple_unpack(x, y): 11915*da0073e9SAndroid Build Coastguard Worker for i, j in [[3, 4], [5, 6], [7, 8]]: 11916*da0073e9SAndroid Build Coastguard Worker x += i 11917*da0073e9SAndroid Build Coastguard Worker y += j 11918*da0073e9SAndroid Build Coastguard Worker return x, y 11919*da0073e9SAndroid Build Coastguard Worker 11920*da0073e9SAndroid Build Coastguard Worker self.checkScript(for_tuple_unpack, (torch.tensor(3), torch.tensor(5))) 11921*da0073e9SAndroid Build Coastguard Worker 11922*da0073e9SAndroid Build Coastguard Worker def nested_tuple_unpack(x, y): 11923*da0073e9SAndroid Build Coastguard Worker # type: (List[int], List[int]) -> int 11924*da0073e9SAndroid Build Coastguard Worker sum = 0 11925*da0073e9SAndroid Build Coastguard Worker for i, (j, k), v in zip(x, enumerate(x), y): 11926*da0073e9SAndroid Build Coastguard Worker sum += i + j + k + v 11927*da0073e9SAndroid Build Coastguard Worker return sum 11928*da0073e9SAndroid Build Coastguard Worker 11929*da0073e9SAndroid Build Coastguard Worker self.checkScript(nested_tuple_unpack, ([1, 3, 5], [2, 4, 6])) 11930*da0073e9SAndroid Build Coastguard Worker 11931*da0073e9SAndroid Build Coastguard Worker def test_for_tuple_assign(self): 11932*da0073e9SAndroid Build Coastguard Worker def test_simple_assign(x): 11933*da0073e9SAndroid Build Coastguard Worker # type: (Tuple[int, float]) -> float 11934*da0073e9SAndroid Build Coastguard Worker sum = 0.0 11935*da0073e9SAndroid Build Coastguard Worker 
for a in x: 11936*da0073e9SAndroid Build Coastguard Worker sum += float(a) 11937*da0073e9SAndroid Build Coastguard Worker return sum 11938*da0073e9SAndroid Build Coastguard Worker 11939*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_simple_assign, ((1, 2.5),)) 11940*da0073e9SAndroid Build Coastguard Worker 11941*da0073e9SAndroid Build Coastguard Worker def test_tuple_assign(x): 11942*da0073e9SAndroid Build Coastguard Worker # type: (Tuple[Tuple[int, int], Tuple[int, int]]) -> int 11943*da0073e9SAndroid Build Coastguard Worker sum = 0 11944*da0073e9SAndroid Build Coastguard Worker for a in x: 11945*da0073e9SAndroid Build Coastguard Worker sum += a[0] 11946*da0073e9SAndroid Build Coastguard Worker sum += a[1] 11947*da0073e9SAndroid Build Coastguard Worker return sum 11948*da0073e9SAndroid Build Coastguard Worker 11949*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_tuple_assign, (((1, 2), (4, 7)), )) 11950*da0073e9SAndroid Build Coastguard Worker 11951*da0073e9SAndroid Build Coastguard Worker def test_single_starred_lhs(self): 11952*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence' 11953*da0073e9SAndroid Build Coastguard Worker ' of another non-starred expression'): 11954*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 11955*da0073e9SAndroid Build Coastguard Worker def single_starred_lhs(x): 11956*da0073e9SAndroid Build Coastguard Worker a = (x, x, x) 11957*da0073e9SAndroid Build Coastguard Worker *b, = a 11958*da0073e9SAndroid Build Coastguard Worker return b 11959*da0073e9SAndroid Build Coastguard Worker ''') 11960*da0073e9SAndroid Build Coastguard Worker 11961*da0073e9SAndroid Build Coastguard Worker def test_singleton_tuple_unpack(self): 11962*da0073e9SAndroid Build Coastguard Worker def foo(a): 11963*da0073e9SAndroid Build Coastguard Worker b, = (a,) 11964*da0073e9SAndroid Build Coastguard Worker return 
b + 1 11965*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(3),)) 11966*da0073e9SAndroid Build Coastguard Worker 11967*da0073e9SAndroid Build Coastguard Worker def test_tuple_assignments(self): 11968*da0073e9SAndroid Build Coastguard Worker def var_tuple_assign(x, y): 11969*da0073e9SAndroid Build Coastguard Worker # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor 11970*da0073e9SAndroid Build Coastguard Worker (a, b), c = x, y 11971*da0073e9SAndroid Build Coastguard Worker return a + b + c 11972*da0073e9SAndroid Build Coastguard Worker 11973*da0073e9SAndroid Build Coastguard Worker tuple_inputs = (torch.randn(1, 4), torch.randn(3, 4)) 11974*da0073e9SAndroid Build Coastguard Worker self.checkScript(var_tuple_assign, (tuple_inputs, torch.randn(3, 4))) 11975*da0073e9SAndroid Build Coastguard Worker 11976*da0073e9SAndroid Build Coastguard Worker def nested_tuple_assign(x, y, z): 11977*da0073e9SAndroid Build Coastguard Worker # type: (int, Tuple[int, Tuple[int, int]], Tuple[int, int]) -> int 11978*da0073e9SAndroid Build Coastguard Worker a, (b, (c, d)), (e, f) = x, y, z 11979*da0073e9SAndroid Build Coastguard Worker return a + b + c + d + e + f 11980*da0073e9SAndroid Build Coastguard Worker 11981*da0073e9SAndroid Build Coastguard Worker self.checkScript(nested_tuple_assign, ((1, (2, (3, 4)), (5, 6)))) 11982*da0073e9SAndroid Build Coastguard Worker 11983*da0073e9SAndroid Build Coastguard Worker def subscript_tuple_assign(a, x, i): 11984*da0073e9SAndroid Build Coastguard Worker # type: (List[int], Tensor, int) -> Tuple[int, Tensor, int] 11985*da0073e9SAndroid Build Coastguard Worker a[i], (x[i], b) = 1, (2, 3) 11986*da0073e9SAndroid Build Coastguard Worker return a[i] + 1, x + 5, b 11987*da0073e9SAndroid Build Coastguard Worker 11988*da0073e9SAndroid Build Coastguard Worker self.checkScript(subscript_tuple_assign, ([12, 7, 9, 11], torch.tensor((3, 13, 17)), 0)) 11989*da0073e9SAndroid Build Coastguard Worker 11990*da0073e9SAndroid Build Coastguard 
Worker def star_tuple_assign(): 11991*da0073e9SAndroid Build Coastguard Worker # type: () -> Tuple[int, int, Tuple[int, int], Tuple[int, int]] 11992*da0073e9SAndroid Build Coastguard Worker a, (b, *c), *d = 1, (2, 3, 4), 5, 6 11993*da0073e9SAndroid Build Coastguard Worker return a, b, c, d 11994*da0073e9SAndroid Build Coastguard Worker 11995*da0073e9SAndroid Build Coastguard Worker self.checkScript(star_tuple_assign, ()) 11996*da0073e9SAndroid Build Coastguard Worker 11997*da0073e9SAndroid Build Coastguard Worker def subscript_tuple_augmented_assign(a): 11998*da0073e9SAndroid Build Coastguard Worker # type: (Tuple[int, int]) -> Tuple[int, int] 11999*da0073e9SAndroid Build Coastguard Worker a[0] += 1 12000*da0073e9SAndroid Build Coastguard Worker return a 12001*da0073e9SAndroid Build Coastguard Worker 12002*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'does not support augmented assign'): 12003*da0073e9SAndroid Build Coastguard Worker scripted_aug_assign = torch.jit.script(subscript_tuple_augmented_assign) 12004*da0073e9SAndroid Build Coastguard Worker 12005*da0073e9SAndroid Build Coastguard Worker class AttrTupleAssignmentTestClass: 12006*da0073e9SAndroid Build Coastguard Worker def __init__(self, a: int, b: int): 12007*da0073e9SAndroid Build Coastguard Worker self.a = a 12008*da0073e9SAndroid Build Coastguard Worker self.b = b 12009*da0073e9SAndroid Build Coastguard Worker 12010*da0073e9SAndroid Build Coastguard Worker def set_ab(self, a: int, b: int): 12011*da0073e9SAndroid Build Coastguard Worker self.a, self.b = (a, b) 12012*da0073e9SAndroid Build Coastguard Worker 12013*da0073e9SAndroid Build Coastguard Worker def get(self) -> Tuple[int, int]: 12014*da0073e9SAndroid Build Coastguard Worker return (self.a, self.b) 12015*da0073e9SAndroid Build Coastguard Worker 12016*da0073e9SAndroid Build Coastguard Worker make_global(AttrTupleAssignmentTestClass) 12017*da0073e9SAndroid Build Coastguard Worker 12018*da0073e9SAndroid Build 
Coastguard Worker @torch.jit.script 12019*da0073e9SAndroid Build Coastguard Worker def attr_tuple_assignment(o: AttrTupleAssignmentTestClass, a: int, b: int): 12020*da0073e9SAndroid Build Coastguard Worker o.set_ab(a, b) 12021*da0073e9SAndroid Build Coastguard Worker return o 12022*da0073e9SAndroid Build Coastguard Worker 12023*da0073e9SAndroid Build Coastguard Worker o = AttrTupleAssignmentTestClass(1, 2) 12024*da0073e9SAndroid Build Coastguard Worker self.assertEqual(attr_tuple_assignment(o, 3, 4).get(), (3, 4)) 12025*da0073e9SAndroid Build Coastguard Worker 12026*da0073e9SAndroid Build Coastguard Worker def test_multiple_assign(self): 12027*da0073e9SAndroid Build Coastguard Worker def test(): 12028*da0073e9SAndroid Build Coastguard Worker a = b, c = d, f = (1, 1) 12029*da0073e9SAndroid Build Coastguard Worker 12030*da0073e9SAndroid Build Coastguard Worker # side effect 12031*da0073e9SAndroid Build Coastguard Worker ten = torch.tensor(1) 12032*da0073e9SAndroid Build Coastguard Worker ten1 = ten2 = ten.add_(1) 12033*da0073e9SAndroid Build Coastguard Worker 12034*da0073e9SAndroid Build Coastguard Worker # ordering 12035*da0073e9SAndroid Build Coastguard Worker x = 1 12036*da0073e9SAndroid Build Coastguard Worker y = 3 12037*da0073e9SAndroid Build Coastguard Worker x, y = y, x + y 12038*da0073e9SAndroid Build Coastguard Worker 12039*da0073e9SAndroid Build Coastguard Worker return a, b, c, d, f, ten, ten1, ten2, x, y 12040*da0073e9SAndroid Build Coastguard Worker 12041*da0073e9SAndroid Build Coastguard Worker self.checkScript(test, ()) 12042*da0073e9SAndroid Build Coastguard Worker 12043*da0073e9SAndroid Build Coastguard Worker def test_multi_reduction(self): 12044*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 12045*da0073e9SAndroid Build Coastguard Worker RuntimeError, 12046*da0073e9SAndroid Build Coastguard Worker 'augmented assignment can only have one LHS expression'): 12047*da0073e9SAndroid Build Coastguard Worker cu = 
torch.jit.CompilationUnit(''' 12048*da0073e9SAndroid Build Coastguard Worker def multi_reduction(x): 12049*da0073e9SAndroid Build Coastguard Worker a, b += x 12050*da0073e9SAndroid Build Coastguard Worker return a, b 12051*da0073e9SAndroid Build Coastguard Worker ''') 12052*da0073e9SAndroid Build Coastguard Worker 12053*da0073e9SAndroid Build Coastguard Worker def test_invalid_call_arguments(self): 12054*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'but instead found type '): 12055*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12056*da0073e9SAndroid Build Coastguard Worker def invalid_call_arguments(x): 12057*da0073e9SAndroid Build Coastguard Worker return torch.unsqueeze(3, 4, 5, 6, 7, 8) 12058*da0073e9SAndroid Build Coastguard Worker 12059*da0073e9SAndroid Build Coastguard Worker def test_invalid_lhs_assignment(self): 12060*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'unexpected expression'): 12061*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 12062*da0073e9SAndroid Build Coastguard Worker def invalid_lhs_assignment(x): 12063*da0073e9SAndroid Build Coastguard Worker x + 1 = x 12064*da0073e9SAndroid Build Coastguard Worker return x 12065*da0073e9SAndroid Build Coastguard Worker ''') 12066*da0073e9SAndroid Build Coastguard Worker 12067*da0073e9SAndroid Build Coastguard Worker def test_multi_starred_expr_lhs(self): 12068*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'): 12069*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 12070*da0073e9SAndroid Build Coastguard Worker def multi_starred_expr_lhs(): 12071*da0073e9SAndroid Build Coastguard Worker a, *b, *c = [1, 2, 3, 4, 5, 6] 12072*da0073e9SAndroid Build Coastguard Worker return a 12073*da0073e9SAndroid Build Coastguard Worker ''') 12074*da0073e9SAndroid Build Coastguard Worker 
12075*da0073e9SAndroid Build Coastguard Worker def test_pack_tuple_into_non_var(self): 12076*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'): 12077*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 12078*da0073e9SAndroid Build Coastguard Worker def pack_tuple_into_non_var(x): 12079*da0073e9SAndroid Build Coastguard Worker a, *1 = (3, 4, 5) 12080*da0073e9SAndroid Build Coastguard Worker return x 12081*da0073e9SAndroid Build Coastguard Worker ''') 12082*da0073e9SAndroid Build Coastguard Worker 12083*da0073e9SAndroid Build Coastguard Worker def test_print_kwargs(self): 12084*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'): 12085*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 12086*da0073e9SAndroid Build Coastguard Worker def print_kwargs(x): 12087*da0073e9SAndroid Build Coastguard Worker print(x, flush=True) 12088*da0073e9SAndroid Build Coastguard Worker return x 12089*da0073e9SAndroid Build Coastguard Worker ''') 12090*da0073e9SAndroid Build Coastguard Worker 12091*da0073e9SAndroid Build Coastguard Worker def test_builtin_use_as_value(self): 12092*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'): 12093*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12094*da0073e9SAndroid Build Coastguard Worker def builtin_use_as_value(x): 12095*da0073e9SAndroid Build Coastguard Worker return x.unsqueeze 12096*da0073e9SAndroid Build Coastguard Worker 12097*da0073e9SAndroid Build Coastguard Worker def test_wrong_use_as_tuple(self): 12098*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'): 12099*da0073e9SAndroid Build Coastguard Worker def test_fn(): 12100*da0073e9SAndroid Build Coastguard Worker return 3 12101*da0073e9SAndroid Build 
Coastguard Worker 12102*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12103*da0073e9SAndroid Build Coastguard Worker def wrong_use_as_tuple(self): 12104*da0073e9SAndroid Build Coastguard Worker a, b = test_fn 12105*da0073e9SAndroid Build Coastguard Worker return a 12106*da0073e9SAndroid Build Coastguard Worker 12107*da0073e9SAndroid Build Coastguard Worker def test_wrong_attr_lookup(self): 12108*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'): 12109*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12110*da0073e9SAndroid Build Coastguard Worker def wrong_attr_lookup(self, x): 12111*da0073e9SAndroid Build Coastguard Worker a = x.unsqueeze.myattr 12112*da0073e9SAndroid Build Coastguard Worker return a 12113*da0073e9SAndroid Build Coastguard Worker 12114*da0073e9SAndroid Build Coastguard Worker def test_wrong_use_as_callable(self): 12115*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'cannot call a value'): 12116*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12117*da0073e9SAndroid Build Coastguard Worker def wrong_use_as_callable(x): 12118*da0073e9SAndroid Build Coastguard Worker return x(3, 4, 5) 12119*da0073e9SAndroid Build Coastguard Worker 12120*da0073e9SAndroid Build Coastguard Worker def test_python_val_doesnt_have_attr(self): 12121*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'): 12122*da0073e9SAndroid Build Coastguard Worker 12123*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12124*da0073e9SAndroid Build Coastguard Worker def python_val_doesnt_have_attr(): 12125*da0073e9SAndroid Build Coastguard Worker # this has to be a module otherwise attr lookup would not be 12126*da0073e9SAndroid Build Coastguard Worker # allowed in the first place 12127*da0073e9SAndroid Build Coastguard Worker return shutil.abcd 12128*da0073e9SAndroid Build 
Coastguard Worker 12129*da0073e9SAndroid Build Coastguard Worker def test_wrong_module_attr_lookup(self): 12130*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value'): 12131*da0073e9SAndroid Build Coastguard Worker import io 12132*da0073e9SAndroid Build Coastguard Worker 12133*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12134*da0073e9SAndroid Build Coastguard Worker def wrong_module_attr_lookup(): 12135*da0073e9SAndroid Build Coastguard Worker return io.BytesIO 12136*da0073e9SAndroid Build Coastguard Worker 12137*da0073e9SAndroid Build Coastguard Worker def test_wrong_method_call_inputs(self): 12138*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Argument y not provided'): 12139*da0073e9SAndroid Build Coastguard Worker class SomeModule(torch.jit.ScriptModule): 12140*da0073e9SAndroid Build Coastguard Worker 12141*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12142*da0073e9SAndroid Build Coastguard Worker def foo(self, x, y): 12143*da0073e9SAndroid Build Coastguard Worker return x 12144*da0073e9SAndroid Build Coastguard Worker 12145*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12146*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 12147*da0073e9SAndroid Build Coastguard Worker return self.foo(x) 12148*da0073e9SAndroid Build Coastguard Worker SomeModule() 12149*da0073e9SAndroid Build Coastguard Worker 12150*da0073e9SAndroid Build Coastguard Worker def test_single_starred_expr_for_loop(self): 12151*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear'): 12152*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(''' 12153*da0073e9SAndroid Build Coastguard Worker def test(): 12154*da0073e9SAndroid Build Coastguard Worker x = 0 12155*da0073e9SAndroid Build Coastguard Worker for *a in [1, 2, 3]: 
12156*da0073e9SAndroid Build Coastguard Worker x = x + 1 12157*da0073e9SAndroid Build Coastguard Worker return x 12158*da0073e9SAndroid Build Coastguard Worker ''') 12159*da0073e9SAndroid Build Coastguard Worker 12160*da0073e9SAndroid Build Coastguard Worker def test_call_ge(self): 12161*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Expected at most 1 arguments but found 3'): 12162*da0073e9SAndroid Build Coastguard Worker @_trace(torch.zeros(1, 2, 3)) 12163*da0073e9SAndroid Build Coastguard Worker def foo(x): 12164*da0073e9SAndroid Build Coastguard Worker return x 12165*da0073e9SAndroid Build Coastguard Worker 12166*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12167*da0073e9SAndroid Build Coastguard Worker def test_fn(): 12168*da0073e9SAndroid Build Coastguard Worker return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3)) 12169*da0073e9SAndroid Build Coastguard Worker 12170*da0073e9SAndroid Build Coastguard Worker def test_wrong_return_type(self): 12171*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'): 12172*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 12173*da0073e9SAndroid Build Coastguard Worker def somefunc(): 12174*da0073e9SAndroid Build Coastguard Worker # type: () -> Tuple[Tuple[Tensor, Tensor]] 12175*da0073e9SAndroid Build Coastguard Worker return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484 12176*da0073e9SAndroid Build Coastguard Worker 12177*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12178*da0073e9SAndroid Build Coastguard Worker def wrong_return_type(): 12179*da0073e9SAndroid Build Coastguard Worker return somefunc() 12180*da0073e9SAndroid Build Coastguard Worker wrong_return_type() 12181*da0073e9SAndroid Build Coastguard Worker 12182*da0073e9SAndroid Build Coastguard Worker # Tests for calling between different front-end modes 12183*da0073e9SAndroid Build Coastguard Worker def 
test_call_python_fn_from_tracing_fn(self): 12184*da0073e9SAndroid Build Coastguard Worker def python_fn(x): 12185*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 12186*da0073e9SAndroid Build Coastguard Worker 12187*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 4)) 12188*da0073e9SAndroid Build Coastguard Worker def traced_fn(x): 12189*da0073e9SAndroid Build Coastguard Worker return python_fn(x) + 1 12190*da0073e9SAndroid Build Coastguard Worker 12191*da0073e9SAndroid Build Coastguard Worker # The neg op in the python function should be properly inlined to the 12192*da0073e9SAndroid Build Coastguard Worker # graph 12193*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::neg").run(str(traced_fn.graph)) 12194*da0073e9SAndroid Build Coastguard Worker 12195*da0073e9SAndroid Build Coastguard Worker def test_call_python_mod_from_tracing_fn(self): 12196*da0073e9SAndroid Build Coastguard Worker class PythonMod(torch.nn.Module): 12197*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12198*da0073e9SAndroid Build Coastguard Worker super().__init__() 12199*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False) 12200*da0073e9SAndroid Build Coastguard Worker 12201*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12202*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, self.param) 12203*da0073e9SAndroid Build Coastguard Worker 12204*da0073e9SAndroid Build Coastguard Worker pm = PythonMod() 12205*da0073e9SAndroid Build Coastguard Worker 12206*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 4)) 12207*da0073e9SAndroid Build Coastguard Worker def traced_fn(x): 12208*da0073e9SAndroid Build Coastguard Worker return pm(x) + 1.0 12209*da0073e9SAndroid Build Coastguard Worker 12210*da0073e9SAndroid Build Coastguard Worker # Note: the parameter self.param from the Python module is inlined 12211*da0073e9SAndroid Build Coastguard Worker 
# into the graph 12212*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(list(traced_fn.graph.inputs())) == 1) 12213*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph)) 12214*da0073e9SAndroid Build Coastguard Worker 12215*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 12216*da0073e9SAndroid Build Coastguard Worker def test_call_traced_fn_from_tracing_fn(self): 12217*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 4)) 12218*da0073e9SAndroid Build Coastguard Worker def traced_fn1(x): 12219*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 12220*da0073e9SAndroid Build Coastguard Worker 12221*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 4)) 12222*da0073e9SAndroid Build Coastguard Worker def traced_fn(x): 12223*da0073e9SAndroid Build Coastguard Worker return traced_fn1(x) + 1 12224*da0073e9SAndroid Build Coastguard Worker 12225*da0073e9SAndroid Build Coastguard Worker FileCheck().check("traced_fn").check("prim::CallFunction").check("aten::add") \ 12226*da0073e9SAndroid Build Coastguard Worker .run(str(traced_fn.graph)) 12227*da0073e9SAndroid Build Coastguard Worker 12228*da0073e9SAndroid Build Coastguard Worker @unittest.skip("error in first class mode") 12229*da0073e9SAndroid Build Coastguard Worker def test_call_traced_mod_from_tracing_fn(self): 12230*da0073e9SAndroid Build Coastguard Worker class TracedModule(torch.nn.Module): 12231*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12232*da0073e9SAndroid Build Coastguard Worker super().__init__() 12233*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False) 12234*da0073e9SAndroid Build Coastguard Worker 12235*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12236*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, self.param) 12237*da0073e9SAndroid Build Coastguard Worker 
12238*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 12239*da0073e9SAndroid Build Coastguard Worker 12240*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"): 12241*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 4)) 12242*da0073e9SAndroid Build Coastguard Worker def traced_fn(x): 12243*da0073e9SAndroid Build Coastguard Worker return tm(x) + 1.0 12244*da0073e9SAndroid Build Coastguard Worker 12245*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 12246*da0073e9SAndroid Build Coastguard Worker def test_call_script_fn_from_tracing_fn(self): 12247*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12248*da0073e9SAndroid Build Coastguard Worker def script_fn(x): 12249*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 12250*da0073e9SAndroid Build Coastguard Worker 12251*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 4)) 12252*da0073e9SAndroid Build Coastguard Worker def traced_fn(x): 12253*da0073e9SAndroid Build Coastguard Worker return script_fn(x) + 1 12254*da0073e9SAndroid Build Coastguard Worker 12255*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::CallFunction").check("aten::add").run(str(traced_fn.graph)) 12256*da0073e9SAndroid Build Coastguard Worker 12257*da0073e9SAndroid Build Coastguard Worker @unittest.skip("error in first class mode") 12258*da0073e9SAndroid Build Coastguard Worker def test_call_script_mod_from_tracing_fn(self): 12259*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "must be registered as submodules"): 12260*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 12261*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12262*da0073e9SAndroid Build Coastguard Worker super().__init__() 12263*da0073e9SAndroid Build Coastguard Worker self.param = 
torch.nn.Parameter(torch.rand(3, 4), requires_grad=False) 12264*da0073e9SAndroid Build Coastguard Worker 12265*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12266*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12267*da0073e9SAndroid Build Coastguard Worker for _i in range(4): 12268*da0073e9SAndroid Build Coastguard Worker x += self.param 12269*da0073e9SAndroid Build Coastguard Worker return x 12270*da0073e9SAndroid Build Coastguard Worker 12271*da0073e9SAndroid Build Coastguard Worker sm = ScriptMod() 12272*da0073e9SAndroid Build Coastguard Worker 12273*da0073e9SAndroid Build Coastguard Worker @_trace(torch.rand(3, 4)) 12274*da0073e9SAndroid Build Coastguard Worker def traced_fn(x): 12275*da0073e9SAndroid Build Coastguard Worker return sm(x) + 1.0 12276*da0073e9SAndroid Build Coastguard Worker 12277*da0073e9SAndroid Build Coastguard Worker 12278*da0073e9SAndroid Build Coastguard Worker def test_call_python_fn_from_traced_module(self): 12279*da0073e9SAndroid Build Coastguard Worker def python_fn(x): 12280*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 12281*da0073e9SAndroid Build Coastguard Worker 12282*da0073e9SAndroid Build Coastguard Worker class TracedModule(torch.nn.Module): 12283*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12284*da0073e9SAndroid Build Coastguard Worker super().__init__() 12285*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 3)) 12286*da0073e9SAndroid Build Coastguard Worker 12287*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12288*da0073e9SAndroid Build Coastguard Worker return torch.mm(python_fn(x), self.param) 12289*da0073e9SAndroid Build Coastguard Worker 12290*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 12291*da0073e9SAndroid Build Coastguard Worker 12292*da0073e9SAndroid Build Coastguard Worker # Note: parameter self.param from the traced module should appear 
as 12293*da0073e9SAndroid Build Coastguard Worker # an input to the graph and the neg op from the Python function should 12294*da0073e9SAndroid Build Coastguard Worker # be properly inlined 12295*da0073e9SAndroid Build Coastguard Worker self.assertTrue(len(list(tm.graph.inputs())) == 2) 12296*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph)) 12297*da0073e9SAndroid Build Coastguard Worker 12298*da0073e9SAndroid Build Coastguard Worker def test_call_python_mod_from_traced_module(self): 12299*da0073e9SAndroid Build Coastguard Worker class PythonModule(torch.nn.Module): 12300*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12301*da0073e9SAndroid Build Coastguard Worker super().__init__() 12302*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(5, 7)) 12303*da0073e9SAndroid Build Coastguard Worker 12304*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12305*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, self.param) 12306*da0073e9SAndroid Build Coastguard Worker 12307*da0073e9SAndroid Build Coastguard Worker class TracedModule(torch.nn.Module): 12308*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12309*da0073e9SAndroid Build Coastguard Worker super().__init__() 12310*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 5)) 12311*da0073e9SAndroid Build Coastguard Worker self.mod = PythonModule() 12312*da0073e9SAndroid Build Coastguard Worker 12313*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12314*da0073e9SAndroid Build Coastguard Worker return self.mod(torch.mm(x, self.param)) + 1.0 12315*da0073e9SAndroid Build Coastguard Worker 12316*da0073e9SAndroid Build Coastguard Worker tm = torch.jit.trace(TracedModule(), torch.rand(3, 4)) 12317*da0073e9SAndroid Build Coastguard Worker 12318*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check_not("value=<Tensor>").check("aten::mm")\ 12319*da0073e9SAndroid Build Coastguard Worker .check('prim::CallMethod[name="forward"]').check("aten::add") \ 12320*da0073e9SAndroid Build Coastguard Worker .run(str(tm.graph)) 12321*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").run(str(tm.mod.graph)) 12322*da0073e9SAndroid Build Coastguard Worker 12323*da0073e9SAndroid Build Coastguard Worker def test_op_dtype(self): 12324*da0073e9SAndroid Build Coastguard Worker 12325*da0073e9SAndroid Build Coastguard Worker def check_equal_and_dtype(a, b): 12326*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a, b) 12327*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.dtype, b.dtype) 12328*da0073e9SAndroid Build Coastguard Worker 12329*da0073e9SAndroid Build Coastguard Worker def fn(): 12330*da0073e9SAndroid Build Coastguard Worker a = torch.arange(10) 12331*da0073e9SAndroid Build Coastguard Worker b = torch.arange(10, dtype=torch.float) 12332*da0073e9SAndroid Build Coastguard Worker c = torch.arange(1, 10, 2) 12333*da0073e9SAndroid Build Coastguard Worker d = torch.arange(1, 10, 2, dtype=torch.float) 12334*da0073e9SAndroid Build Coastguard Worker e = torch.arange(1, 10., 2) 12335*da0073e9SAndroid Build Coastguard Worker f = torch.arange(1, 10., 2, dtype=torch.float) 12336*da0073e9SAndroid Build Coastguard Worker return a, b, c, d, e, f 12337*da0073e9SAndroid Build Coastguard Worker 12338*da0073e9SAndroid Build Coastguard Worker scripted_fn = torch.jit.script(fn) 12339*da0073e9SAndroid Build Coastguard Worker eager_out = fn() 12340*da0073e9SAndroid Build Coastguard Worker script_out = scripted_fn() 12341*da0073e9SAndroid Build Coastguard Worker for a, b in zip(eager_out, script_out): 12342*da0073e9SAndroid Build Coastguard Worker check_equal_and_dtype(a, b) 12343*da0073e9SAndroid Build Coastguard Worker 12344*da0073e9SAndroid Build Coastguard Worker def test_floor_div(self): 12345*da0073e9SAndroid Build Coastguard Worker 
@torch.jit.script 12346*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 12347*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> int 12348*da0073e9SAndroid Build Coastguard Worker return a // b 12349*da0073e9SAndroid Build Coastguard Worker for i in range(-8, 8): 12350*da0073e9SAndroid Build Coastguard Worker for j in range(-8, 8): 12351*da0073e9SAndroid Build Coastguard Worker if j != 0: 12352*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(i, j), i // j) 12353*da0073e9SAndroid Build Coastguard Worker 12354*da0073e9SAndroid Build Coastguard Worker def test_floordiv(self): 12355*da0073e9SAndroid Build Coastguard Worker funcs_template = dedent(''' 12356*da0073e9SAndroid Build Coastguard Worker def fn(): 12357*da0073e9SAndroid Build Coastguard Worker ten = {a_construct} 12358*da0073e9SAndroid Build Coastguard Worker ten_or_scalar = {b_construct} 12359*da0073e9SAndroid Build Coastguard Worker return ten // ten_or_scalar, torch.floor_divide(ten, ten_or_scalar) 12360*da0073e9SAndroid Build Coastguard Worker ''') 12361*da0073e9SAndroid Build Coastguard Worker 12362*da0073e9SAndroid Build Coastguard Worker lhs = ["torch.tensor([5.5, 3.2])", "torch.tensor([2, 2])", "torch.tensor([3, 2])"] 12363*da0073e9SAndroid Build Coastguard Worker rhs = ["1.5", "2", "4", "1.1"] + lhs 12364*da0073e9SAndroid Build Coastguard Worker for tensor in lhs: 12365*da0073e9SAndroid Build Coastguard Worker for tensor_or_scalar in rhs: 12366*da0073e9SAndroid Build Coastguard Worker funcs_str = funcs_template.format(a_construct=tensor, b_construct=tensor_or_scalar) 12367*da0073e9SAndroid Build Coastguard Worker scope = {} 12368*da0073e9SAndroid Build Coastguard Worker execWrapper(funcs_str, globals(), scope) 12369*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(funcs_str) 12370*da0073e9SAndroid Build Coastguard Worker f_script = cu.fn 12371*da0073e9SAndroid Build Coastguard Worker f = scope['fn'] 12372*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(f_script(), f()) 12373*da0073e9SAndroid Build Coastguard Worker 12374*da0073e9SAndroid Build Coastguard Worker def test_call_python_fn_from_script_fn(self): 12375*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 12376*da0073e9SAndroid Build Coastguard Worker def python_fn(x): 12377*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 12378*da0073e9SAndroid Build Coastguard Worker 12379*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12380*da0073e9SAndroid Build Coastguard Worker def script_fn(x): 12381*da0073e9SAndroid Build Coastguard Worker return python_fn(x) + 1 12382*da0073e9SAndroid Build Coastguard Worker 12383*da0073e9SAndroid Build Coastguard Worker # Note: the call to python_fn appears as `^python_fn()` and is called 12384*da0073e9SAndroid Build Coastguard Worker # as a PythonOp in the interpreter 12385*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(1) 12386*da0073e9SAndroid Build Coastguard Worker self.assertEqual(script_fn(a), torch.tensor(0)) 12387*da0073e9SAndroid Build Coastguard Worker FileCheck().check("python_fn").run(str(script_fn.graph)) 12388*da0073e9SAndroid Build Coastguard Worker 12389*da0073e9SAndroid Build Coastguard Worker def test_call_python_mod_from_script_fn(self): 12390*da0073e9SAndroid Build Coastguard Worker class PythonModule(torch.nn.Module): 12391*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12392*da0073e9SAndroid Build Coastguard Worker super().__init__() 12393*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(5, 7)) 12394*da0073e9SAndroid Build Coastguard Worker 12395*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12396*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, self.param) 12397*da0073e9SAndroid Build Coastguard Worker 12398*da0073e9SAndroid Build Coastguard Worker pm = PythonModule() 12399*da0073e9SAndroid Build Coastguard Worker 12400*da0073e9SAndroid Build Coastguard Worker 
@torch.jit.script 12401*da0073e9SAndroid Build Coastguard Worker def script_fn(x): 12402*da0073e9SAndroid Build Coastguard Worker return pm(x) + 1 12403*da0073e9SAndroid Build Coastguard Worker 12404*da0073e9SAndroid Build Coastguard Worker # Note: call to pm(x) appears as ^<python_value>() in the trace. 12405*da0073e9SAndroid Build Coastguard Worker # Parameters are NOT inlined. 12406*da0073e9SAndroid Build Coastguard Worker FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph)) 12407*da0073e9SAndroid Build Coastguard Worker 12408*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 12409*da0073e9SAndroid Build Coastguard Worker def test_call_script_fn_from_script_fn(self): 12410*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12411*da0073e9SAndroid Build Coastguard Worker def script_fn1(x): 12412*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 12413*da0073e9SAndroid Build Coastguard Worker 12414*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12415*da0073e9SAndroid Build Coastguard Worker def script_fn(x): 12416*da0073e9SAndroid Build Coastguard Worker return script_fn1(x) + 1 12417*da0073e9SAndroid Build Coastguard Worker 12418*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::CallFunction").run(str(script_fn.graph)) 12419*da0073e9SAndroid Build Coastguard Worker 12420*da0073e9SAndroid Build Coastguard Worker def test_call_script_mod_from_script_fn(self): 12421*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"): 12422*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 12423*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12424*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12425*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, torch.zeros([4, 3])) 12426*da0073e9SAndroid Build Coastguard 
Worker 12427*da0073e9SAndroid Build Coastguard Worker sm = ScriptMod() 12428*da0073e9SAndroid Build Coastguard Worker 12429*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12430*da0073e9SAndroid Build Coastguard Worker def script_fn(x): 12431*da0073e9SAndroid Build Coastguard Worker return sm(x) + 1 12432*da0073e9SAndroid Build Coastguard Worker 12433*da0073e9SAndroid Build Coastguard Worker def test_call_python_fn_from_script_module(self): 12434*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 12435*da0073e9SAndroid Build Coastguard Worker def python_fn(x): 12436*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 12437*da0073e9SAndroid Build Coastguard Worker 12438*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 12439*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12440*da0073e9SAndroid Build Coastguard Worker super().__init__() 12441*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 3)) 12442*da0073e9SAndroid Build Coastguard Worker 12443*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12444*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12445*da0073e9SAndroid Build Coastguard Worker return python_fn(torch.mm(x, self.param)) 12446*da0073e9SAndroid Build Coastguard Worker 12447*da0073e9SAndroid Build Coastguard Worker sm = ScriptMod() 12448*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").check("python_fn") \ 12449*da0073e9SAndroid Build Coastguard Worker .run(str(sm.forward.graph)) 12450*da0073e9SAndroid Build Coastguard Worker 12451*da0073e9SAndroid Build Coastguard Worker def test_call_python_mod_from_script_module(self): 12452*da0073e9SAndroid Build Coastguard Worker class PythonMod(torch.nn.Module): 12453*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12454*da0073e9SAndroid Build Coastguard Worker super().__init__() 12455*da0073e9SAndroid Build Coastguard 
Worker self.param = torch.nn.Parameter(torch.rand(3, 5)) 12456*da0073e9SAndroid Build Coastguard Worker 12457*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 12458*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12459*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, self.param) 12460*da0073e9SAndroid Build Coastguard Worker 12461*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 12462*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12463*da0073e9SAndroid Build Coastguard Worker super().__init__() 12464*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 3)) 12465*da0073e9SAndroid Build Coastguard Worker self.pm = PythonMod() 12466*da0073e9SAndroid Build Coastguard Worker 12467*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12468*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12469*da0073e9SAndroid Build Coastguard Worker return self.pm(torch.mm(x, self.param)) 12470*da0073e9SAndroid Build Coastguard Worker 12471*da0073e9SAndroid Build Coastguard Worker sm = ScriptMod() 12472*da0073e9SAndroid Build Coastguard Worker # Note: the call into PythonMod appears as ^forward(). 
Parameters 12473*da0073e9SAndroid Build Coastguard Worker # are NOT inlined 12474*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").check("forward").run(str(sm.graph)) 12475*da0073e9SAndroid Build Coastguard Worker 12476*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 12477*da0073e9SAndroid Build Coastguard Worker def test_call_script_fn_from_script_module(self): 12478*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12479*da0073e9SAndroid Build Coastguard Worker def script_fn(x): 12480*da0073e9SAndroid Build Coastguard Worker return torch.neg(x) 12481*da0073e9SAndroid Build Coastguard Worker 12482*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 12483*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12484*da0073e9SAndroid Build Coastguard Worker super().__init__() 12485*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 3)) 12486*da0073e9SAndroid Build Coastguard Worker 12487*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12488*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12489*da0073e9SAndroid Build Coastguard Worker return script_fn(torch.mm(x, self.param)) 12490*da0073e9SAndroid Build Coastguard Worker 12491*da0073e9SAndroid Build Coastguard Worker sm = ScriptMod() 12492*da0073e9SAndroid Build Coastguard Worker graph = (sm.forward.graph) 12493*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::mm").check("prim::CallFunction").run(str(graph)) 12494*da0073e9SAndroid Build Coastguard Worker 12495*da0073e9SAndroid Build Coastguard Worker @_tmp_donotuse_dont_inline_everything 12496*da0073e9SAndroid Build Coastguard Worker def test_call_script_mod_from_script_module(self): 12497*da0073e9SAndroid Build Coastguard Worker class ScriptMod1(torch.jit.ScriptModule): 12498*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12499*da0073e9SAndroid Build 
Coastguard Worker super().__init__() 12500*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 5)) 12501*da0073e9SAndroid Build Coastguard Worker 12502*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12503*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12504*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, self.param) 12505*da0073e9SAndroid Build Coastguard Worker 12506*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 12507*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12508*da0073e9SAndroid Build Coastguard Worker super().__init__() 12509*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(4, 3)) 12510*da0073e9SAndroid Build Coastguard Worker self.tm = ScriptMod1() 12511*da0073e9SAndroid Build Coastguard Worker 12512*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12513*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12514*da0073e9SAndroid Build Coastguard Worker return self.tm(torch.mm(x, self.param)) 12515*da0073e9SAndroid Build Coastguard Worker 12516*da0073e9SAndroid Build Coastguard Worker sm = ScriptMod() 12517*da0073e9SAndroid Build Coastguard Worker # Note: the parameters from both modules should appear in the flattened 12518*da0073e9SAndroid Build Coastguard Worker # input list to the graph. 
The mm op from ScriptMod1 should be properly 12519*da0073e9SAndroid Build Coastguard Worker # inlined 12520*da0073e9SAndroid Build Coastguard Worker # 3 % values in graph input lists, two mms in body 12521*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count('%', 3).check(":").check_count("mm", 1).check("prim::CallMethod").run(str(sm.graph)) 12522*da0073e9SAndroid Build Coastguard Worker 12523*da0073e9SAndroid Build Coastguard Worker def test_module_with_params_called_fails(self): 12524*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"): 12525*da0073e9SAndroid Build Coastguard Worker class ScriptMod(torch.jit.ScriptModule): 12526*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12527*da0073e9SAndroid Build Coastguard Worker super().__init__() 12528*da0073e9SAndroid Build Coastguard Worker self.param = torch.nn.Parameter(torch.rand(3, 3)) 12529*da0073e9SAndroid Build Coastguard Worker 12530*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12531*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 12532*da0073e9SAndroid Build Coastguard Worker return torch.mm(x, self.param) 12533*da0073e9SAndroid Build Coastguard Worker 12534*da0073e9SAndroid Build Coastguard Worker sm = ScriptMod() 12535*da0073e9SAndroid Build Coastguard Worker 12536*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12537*da0073e9SAndroid Build Coastguard Worker def some_func(x): 12538*da0073e9SAndroid Build Coastguard Worker return sm(x) 12539*da0073e9SAndroid Build Coastguard Worker 12540*da0073e9SAndroid Build Coastguard Worker def test_tuple_index_to_list(self): 12541*da0073e9SAndroid Build Coastguard Worker def test_non_constant_input(a): 12542*da0073e9SAndroid Build Coastguard Worker # type: (bool) -> int 12543*da0073e9SAndroid Build Coastguard Worker if a: 12544*da0073e9SAndroid Build Coastguard Worker b = 1 
12545*da0073e9SAndroid Build Coastguard Worker else: 12546*da0073e9SAndroid Build Coastguard Worker b = 0 12547*da0073e9SAndroid Build Coastguard Worker c = (0, 1) 12548*da0073e9SAndroid Build Coastguard Worker return c[b] 12549*da0073e9SAndroid Build Coastguard Worker 12550*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_non_constant_input, (True,)) 12551*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_non_constant_input, (False,)) 12552*da0073e9SAndroid Build Coastguard Worker 12553*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "because we cannot resolve the output type"): 12554*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12555*da0073e9SAndroid Build Coastguard Worker def test_non_constant_input(a): 12556*da0073e9SAndroid Build Coastguard Worker # type: (bool) -> None 12557*da0073e9SAndroid Build Coastguard Worker if a: 12558*da0073e9SAndroid Build Coastguard Worker b = 1 12559*da0073e9SAndroid Build Coastguard Worker else: 12560*da0073e9SAndroid Build Coastguard Worker b = 0 12561*da0073e9SAndroid Build Coastguard Worker c = (0, 1.1) 12562*da0073e9SAndroid Build Coastguard Worker print(c[b]) 12563*da0073e9SAndroid Build Coastguard Worker 12564*da0073e9SAndroid Build Coastguard Worker def test_tuple_indexing(self): 12565*da0073e9SAndroid Build Coastguard Worker def tuple_index(a): 12566*da0073e9SAndroid Build Coastguard Worker if bool(a): 12567*da0073e9SAndroid Build Coastguard Worker b = (1, 2) 12568*da0073e9SAndroid Build Coastguard Worker else: 12569*da0073e9SAndroid Build Coastguard Worker b = (0, 2) 12570*da0073e9SAndroid Build Coastguard Worker return b[-2], b[1] 12571*da0073e9SAndroid Build Coastguard Worker 12572*da0073e9SAndroid Build Coastguard Worker self.checkScript(tuple_index, (torch.tensor([0]),)) 12573*da0073e9SAndroid Build Coastguard Worker self.checkScript(tuple_index, (torch.tensor([1]),)) 12574*da0073e9SAndroid Build Coastguard Worker 
self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True) 12575*da0073e9SAndroid Build Coastguard Worker tuple_comp = torch.jit.script(tuple_index) 12576*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("TupleIndex", 2, exactly=True).run(str(tuple_comp.graph)) 12577*da0073e9SAndroid Build Coastguard Worker 12578*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "index must be an integer"): 12579*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12580*da0073e9SAndroid Build Coastguard Worker def test_indexing_float(): 12581*da0073e9SAndroid Build Coastguard Worker c = (1, 2) 12582*da0073e9SAndroid Build Coastguard Worker return c[0.1] 12583*da0073e9SAndroid Build Coastguard Worker 12584*da0073e9SAndroid Build Coastguard Worker def test_indexing_out_of_bounds_pos(): 12585*da0073e9SAndroid Build Coastguard Worker c = (1, 2) 12586*da0073e9SAndroid Build Coastguard Worker return c[2] 12587*da0073e9SAndroid Build Coastguard Worker 12588*da0073e9SAndroid Build Coastguard Worker self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception, 12589*da0073e9SAndroid Build Coastguard Worker "out of range") 12590*da0073e9SAndroid Build Coastguard Worker 12591*da0073e9SAndroid Build Coastguard Worker def test_indexing_out_of_bounds_neg(): 12592*da0073e9SAndroid Build Coastguard Worker c = (1, 2) 12593*da0073e9SAndroid Build Coastguard Worker return c[-3] 12594*da0073e9SAndroid Build Coastguard Worker 12595*da0073e9SAndroid Build Coastguard Worker self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception, 12596*da0073e9SAndroid Build Coastguard Worker "out of range") 12597*da0073e9SAndroid Build Coastguard Worker 12598*da0073e9SAndroid Build Coastguard Worker def negative_index(): 12599*da0073e9SAndroid Build Coastguard Worker tup = (1, 2, 3, 4) 12600*da0073e9SAndroid Build Coastguard Worker return tup[-1] 12601*da0073e9SAndroid Build Coastguard Worker 12602*da0073e9SAndroid Build 
Coastguard Worker self.checkScript(negative_index, []) 12603*da0073e9SAndroid Build Coastguard Worker 12604*da0073e9SAndroid Build Coastguard Worker def really_negative_index(): 12605*da0073e9SAndroid Build Coastguard Worker tup = (1, 2, 3, 4) 12606*da0073e9SAndroid Build Coastguard Worker return tup[-100] 12607*da0073e9SAndroid Build Coastguard Worker 12608*da0073e9SAndroid Build Coastguard Worker self.checkScriptRaisesRegex(really_negative_index, [], Exception, "index out of range") 12609*da0073e9SAndroid Build Coastguard Worker 12610*da0073e9SAndroid Build Coastguard Worker def negative_slice(): 12611*da0073e9SAndroid Build Coastguard Worker tup = (1, 2, 3, 4) 12612*da0073e9SAndroid Build Coastguard Worker return tup[-3:4] 12613*da0073e9SAndroid Build Coastguard Worker 12614*da0073e9SAndroid Build Coastguard Worker self.checkScript(negative_slice, []) 12615*da0073e9SAndroid Build Coastguard Worker 12616*da0073e9SAndroid Build Coastguard Worker def really_slice_out_of_bounds(): 12617*da0073e9SAndroid Build Coastguard Worker tup = (1, 2, 3, 4) 12618*da0073e9SAndroid Build Coastguard Worker return tup[-300:4000] 12619*da0073e9SAndroid Build Coastguard Worker 12620*da0073e9SAndroid Build Coastguard Worker self.checkScript(really_slice_out_of_bounds, []) 12621*da0073e9SAndroid Build Coastguard Worker 12622*da0073e9SAndroid Build Coastguard Worker def test_namedtuple_attr(self): 12623*da0073e9SAndroid Build Coastguard Worker def f(x): 12624*da0073e9SAndroid Build Coastguard Worker return x.max(dim=1).indices + torch.max(x, dim=1).indices 12625*da0073e9SAndroid Build Coastguard Worker 12626*da0073e9SAndroid Build Coastguard Worker self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True) 12627*da0073e9SAndroid Build Coastguard Worker 12628*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"): 12629*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12630*da0073e9SAndroid Build Coastguard 
Worker def g1(x): 12631*da0073e9SAndroid Build Coastguard Worker return x.max(dim=1).unknown_symbol 12632*da0073e9SAndroid Build Coastguard Worker 12633*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "object has no attribute or method"): 12634*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12635*da0073e9SAndroid Build Coastguard Worker def g2(x): 12636*da0073e9SAndroid Build Coastguard Worker print((x, x, x).__doc__) 12637*da0073e9SAndroid Build Coastguard Worker return x 12638*da0073e9SAndroid Build Coastguard Worker 12639*da0073e9SAndroid Build Coastguard Worker def test_tuple_len(self): 12640*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12641*da0073e9SAndroid Build Coastguard Worker def foo(): 12642*da0073e9SAndroid Build Coastguard Worker return len((1, "str", None)) 12643*da0073e9SAndroid Build Coastguard Worker 12644*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), 3) 12645*da0073e9SAndroid Build Coastguard Worker 12646*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12647*da0073e9SAndroid Build Coastguard Worker def test_indexing_end_out_of_bounds(): 12648*da0073e9SAndroid Build Coastguard Worker c = (1, 2) 12649*da0073e9SAndroid Build Coastguard Worker return c[2:10] 12650*da0073e9SAndroid Build Coastguard Worker 12651*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test_indexing_end_out_of_bounds(), ()) 12652*da0073e9SAndroid Build Coastguard Worker 12653*da0073e9SAndroid Build Coastguard Worker def test_lower_nested_tuples(self): 12654*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12655*da0073e9SAndroid Build Coastguard Worker def test(): 12656*da0073e9SAndroid Build Coastguard Worker return ((1, 2), 3) 12657*da0073e9SAndroid Build Coastguard Worker 12658*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_propagation', test.graph) 12659*da0073e9SAndroid Build Coastguard Worker 
FileCheck().check("prim::Constant").check_not("TupleConstruct").run(test.graph) 12660*da0073e9SAndroid Build Coastguard Worker # fails if a tuple can't be lowered 12661*da0073e9SAndroid Build Coastguard Worker self.run_pass('lower_all_tuples', test.graph) 12662*da0073e9SAndroid Build Coastguard Worker 12663*da0073e9SAndroid Build Coastguard Worker def test_unwrap_optional_builtin(self): 12664*da0073e9SAndroid Build Coastguard Worker def test(x): 12665*da0073e9SAndroid Build Coastguard Worker # type: (Optional[int]) -> int 12666*da0073e9SAndroid Build Coastguard Worker x = torch.jit._unwrap_optional(x) 12667*da0073e9SAndroid Build Coastguard Worker x = x + x # noqa: T484 12668*da0073e9SAndroid Build Coastguard Worker return x 12669*da0073e9SAndroid Build Coastguard Worker 12670*da0073e9SAndroid Build Coastguard Worker self.checkScript(test, (3,)) 12671*da0073e9SAndroid Build Coastguard Worker 12672*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"): 12673*da0073e9SAndroid Build Coastguard Worker test(None) 12674*da0073e9SAndroid Build Coastguard Worker 12675*da0073e9SAndroid Build Coastguard Worker test_script = torch.jit.script(test) 12676*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"): 12677*da0073e9SAndroid Build Coastguard Worker test_script(None) 12678*da0073e9SAndroid Build Coastguard Worker 12679*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12680*da0073e9SAndroid Build Coastguard Worker def test_test(): 12681*da0073e9SAndroid Build Coastguard Worker return torch.jit._unwrap_optional(1) 12682*da0073e9SAndroid Build Coastguard Worker 12683*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"could not be inferred from actual type None"): 12684*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12685*da0073e9SAndroid Build Coastguard Worker def test_no_type(): 12686*da0073e9SAndroid 
Build Coastguard Worker # type: () -> int 12687*da0073e9SAndroid Build Coastguard Worker return torch.jit._unwrap_optional(None) 12688*da0073e9SAndroid Build Coastguard Worker 12689*da0073e9SAndroid Build Coastguard Worker def test_indexing_error(self): 12690*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "'int' object is not subscriptable"): 12691*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12692*da0073e9SAndroid Build Coastguard Worker def test_wrong_type(): 12693*da0073e9SAndroid Build Coastguard Worker a = 8 12694*da0073e9SAndroid Build Coastguard Worker return a[0] 12695*da0073e9SAndroid Build Coastguard Worker 12696*da0073e9SAndroid Build Coastguard Worker def test_unsupported_builtin_error(self): 12697*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 12698*da0073e9SAndroid Build Coastguard Worker "Python builtin <built-in function hypot> is currently"): 12699*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12700*da0073e9SAndroid Build Coastguard Worker def test_unsupported(a): 12701*da0073e9SAndroid Build Coastguard Worker return math.hypot(a, 2.0) 12702*da0073e9SAndroid Build Coastguard Worker 12703*da0073e9SAndroid Build Coastguard Worker def test_annotated_script_fn(self): 12704*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12705*da0073e9SAndroid Build Coastguard Worker def foo(x, y, z): 12706*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor 12707*da0073e9SAndroid Build Coastguard Worker return x 12708*da0073e9SAndroid Build Coastguard Worker 12709*da0073e9SAndroid Build Coastguard Worker self.assertExpected(str(foo.schema)) 12710*da0073e9SAndroid Build Coastguard Worker 12711*da0073e9SAndroid Build Coastguard Worker def test_annotated_script_method(self): 12712*da0073e9SAndroid Build Coastguard Worker class SM(torch.jit.ScriptModule): 12713*da0073e9SAndroid 
Build Coastguard Worker @torch.jit.script_method 12714*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 12715*da0073e9SAndroid Build Coastguard Worker # type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor] 12716*da0073e9SAndroid Build Coastguard Worker return y, y, y 12717*da0073e9SAndroid Build Coastguard Worker 12718*da0073e9SAndroid Build Coastguard Worker sm = SM() 12719*da0073e9SAndroid Build Coastguard Worker 12720*da0073e9SAndroid Build Coastguard Worker self.assertExpectedStripMangled(str(sm.forward.schema)) 12721*da0073e9SAndroid Build Coastguard Worker 12722*da0073e9SAndroid Build Coastguard Worker def test_annotated_script_fn_return_mismatch(self): 12723*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "but is actually of type"): 12724*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12725*da0073e9SAndroid Build Coastguard Worker def return_tup(x): 12726*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor] 12727*da0073e9SAndroid Build Coastguard Worker return x, x # noqa: T484 12728*da0073e9SAndroid Build Coastguard Worker 12729*da0073e9SAndroid Build Coastguard Worker def test_annotated_script_fn_arg_mismatch(self): 12730*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Arguments for call are not valid"): 12731*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12732*da0073e9SAndroid Build Coastguard Worker def tuple_arg(x): 12733*da0073e9SAndroid Build Coastguard Worker # type: (Tuple[Tensor, Tensor]) -> Tensor 12734*da0073e9SAndroid Build Coastguard Worker return x + 1 # noqa: T484 12735*da0073e9SAndroid Build Coastguard Worker 12736*da0073e9SAndroid Build Coastguard Worker def test_script_non_tensor_args_outputs(self): 12737*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12738*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 12739*da0073e9SAndroid Build Coastguard 
Worker # type: (Tensor, float) -> float 12740*da0073e9SAndroid Build Coastguard Worker return float((x + y).sum()) 12741*da0073e9SAndroid Build Coastguard Worker 12742*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 12743*da0073e9SAndroid Build Coastguard Worker z = fn(x, 1) 12744*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(z, float) 12745*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z, 8.) 12746*da0073e9SAndroid Build Coastguard Worker 12747*da0073e9SAndroid Build Coastguard Worker @unittest.skip('https://github.com/pytorch/pytorch/issues/9595') 12748*da0073e9SAndroid Build Coastguard Worker def test_inline_and_run_annotated_script_fn(self): 12749*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12750*da0073e9SAndroid Build Coastguard Worker def to_inline(x, y): 12751*da0073e9SAndroid Build Coastguard Worker # type: (Tuple[Tensor, Tensor], Tensor) -> Tensor 12752*da0073e9SAndroid Build Coastguard Worker return y 12753*da0073e9SAndroid Build Coastguard Worker 12754*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12755*da0073e9SAndroid Build Coastguard Worker def some_func(x): 12756*da0073e9SAndroid Build Coastguard Worker return to_inline((x, x), x) 12757*da0073e9SAndroid Build Coastguard Worker 12758*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 12759*da0073e9SAndroid Build Coastguard Worker self.assertEqual(some_func(x), x) 12760*da0073e9SAndroid Build Coastguard Worker 12761*da0073e9SAndroid Build Coastguard Worker def _make_filereader_test_file(self): 12762*da0073e9SAndroid Build Coastguard Worker filename = tempfile.mktemp() 12763*da0073e9SAndroid Build Coastguard Worker writer = torch._C.PyTorchFileWriter(filename) 12764*da0073e9SAndroid Build Coastguard Worker buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]] 12765*da0073e9SAndroid Build Coastguard Worker offsets = [] 12766*da0073e9SAndroid Build Coastguard Worker for i, buf in 
enumerate(buffers): 12767*da0073e9SAndroid Build Coastguard Worker writer.write_record(str(i), buf, len(buf)) 12768*da0073e9SAndroid Build Coastguard Worker offsets.append(i) 12769*da0073e9SAndroid Build Coastguard Worker serialized_offsets = pickle.dumps(offsets) 12770*da0073e9SAndroid Build Coastguard Worker writer.write_record("meta", serialized_offsets, len(serialized_offsets)) 12771*da0073e9SAndroid Build Coastguard Worker writer.write_end_of_file() 12772*da0073e9SAndroid Build Coastguard Worker return filename, buffers, serialized_offsets 12773*da0073e9SAndroid Build Coastguard Worker 12774*da0073e9SAndroid Build Coastguard Worker def test_file_format_serialization(self): 12775*da0073e9SAndroid Build Coastguard Worker filename, buffers, serialized_offsets = self._make_filereader_test_file() 12776*da0073e9SAndroid Build Coastguard Worker 12777*da0073e9SAndroid Build Coastguard Worker reader = torch._C.PyTorchFileReader(filename) 12778*da0073e9SAndroid Build Coastguard Worker serialized_offsets_read = reader.get_record("meta") 12779*da0073e9SAndroid Build Coastguard Worker parsed_serialized_offsets = pickle.loads(serialized_offsets) 12780*da0073e9SAndroid Build Coastguard Worker 12781*da0073e9SAndroid Build Coastguard Worker for i, offset in enumerate(parsed_serialized_offsets): 12782*da0073e9SAndroid Build Coastguard Worker data = reader.get_record(str(offset)) 12783*da0073e9SAndroid Build Coastguard Worker assert data == buffers[i] 12784*da0073e9SAndroid Build Coastguard Worker 12785*da0073e9SAndroid Build Coastguard Worker def test_file_reader_no_memory_leak(self): 12786*da0073e9SAndroid Build Coastguard Worker num_iters = 10000 12787*da0073e9SAndroid Build Coastguard Worker filename, _, _ = self._make_filereader_test_file() 12788*da0073e9SAndroid Build Coastguard Worker 12789*da0073e9SAndroid Build Coastguard Worker # Load from filename 12790*da0073e9SAndroid Build Coastguard Worker tracemalloc.start() 12791*da0073e9SAndroid Build Coastguard Worker for i in 
range(num_iters): 12792*da0073e9SAndroid Build Coastguard Worker torch._C.PyTorchFileReader(filename) 12793*da0073e9SAndroid Build Coastguard Worker _, peak_from_string = tracemalloc.get_traced_memory() 12794*da0073e9SAndroid Build Coastguard Worker tracemalloc.stop() 12795*da0073e9SAndroid Build Coastguard Worker 12796*da0073e9SAndroid Build Coastguard Worker # Load from stream 12797*da0073e9SAndroid Build Coastguard Worker tracemalloc.start() 12798*da0073e9SAndroid Build Coastguard Worker with open(filename, 'rb') as f: 12799*da0073e9SAndroid Build Coastguard Worker for i in range(num_iters): 12800*da0073e9SAndroid Build Coastguard Worker f.seek(0) 12801*da0073e9SAndroid Build Coastguard Worker torch._C.PyTorchFileReader(f) 12802*da0073e9SAndroid Build Coastguard Worker _, peak_from_file = tracemalloc.get_traced_memory() 12803*da0073e9SAndroid Build Coastguard Worker tracemalloc.stop() 12804*da0073e9SAndroid Build Coastguard Worker 12805*da0073e9SAndroid Build Coastguard Worker # Check if the peak sizes at most differ by an empirically obtained factor 12806*da0073e9SAndroid Build Coastguard Worker self.assertLess(peak_from_file, peak_from_string * 500) 12807*da0073e9SAndroid Build Coastguard Worker 12808*da0073e9SAndroid Build Coastguard Worker # for each type, the input type annotation and corresponding return type annotation 12809*da0073e9SAndroid Build Coastguard Worker def type_input_return_pairs(self): 12810*da0073e9SAndroid Build Coastguard Worker return [ 12811*da0073e9SAndroid Build Coastguard Worker ('Tensor', 'Tensor'), 12812*da0073e9SAndroid Build Coastguard Worker ('torch.Tensor', 'Tensor'), 12813*da0073e9SAndroid Build Coastguard Worker ('str', 'str'), 12814*da0073e9SAndroid Build Coastguard Worker ('int', 'int'), 12815*da0073e9SAndroid Build Coastguard Worker ('bool', 'bool'), 12816*da0073e9SAndroid Build Coastguard Worker ('BroadcastingList3[float]', 'List[float]'), 12817*da0073e9SAndroid Build Coastguard Worker ('BroadcastingList2[int]', 
'List[int]'), 12818*da0073e9SAndroid Build Coastguard Worker ('List[int]', 'List[int]'), 12819*da0073e9SAndroid Build Coastguard Worker ('Optional[int]', 'Optional[int]'), 12820*da0073e9SAndroid Build Coastguard Worker ] 12821*da0073e9SAndroid Build Coastguard Worker 12822*da0073e9SAndroid Build Coastguard Worker # replacing code input & return type pair 12823*da0073e9SAndroid Build Coastguard Worker def format_code(self, code, pair): 12824*da0073e9SAndroid Build Coastguard Worker return code.format(input=pair[0], output=pair[1]) 12825*da0073e9SAndroid Build Coastguard Worker 12826*da0073e9SAndroid Build Coastguard Worker # ***** Type annotation tests **** 12827*da0073e9SAndroid Build Coastguard Worker # Test combinations of: 12828*da0073e9SAndroid Build Coastguard Worker # {String frontend, Python AST Frontend} 12829*da0073e9SAndroid Build Coastguard Worker # {Python 3-style type annotations, MyPy-style type comments} 12830*da0073e9SAndroid Build Coastguard Worker # {Script method, Script function} 12831*da0073e9SAndroid Build Coastguard Worker 12832*da0073e9SAndroid Build Coastguard Worker # String frontend , Python 3-style type annotations , Script function 12833*da0073e9SAndroid Build Coastguard Worker def test_annot_string_py3_fn(self): 12834*da0073e9SAndroid Build Coastguard Worker code = ''' 12835*da0073e9SAndroid Build Coastguard Worker def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 12836*da0073e9SAndroid Build Coastguard Worker return x, x 12837*da0073e9SAndroid Build Coastguard Worker ''' 12838*da0073e9SAndroid Build Coastguard Worker test_str = [] 12839*da0073e9SAndroid Build Coastguard Worker for pair in self.type_input_return_pairs(): 12840*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(self.format_code(code, pair)) 12841*da0073e9SAndroid Build Coastguard Worker test_str.append(str(cu.foo.schema)) 12842*da0073e9SAndroid Build Coastguard Worker self.assertExpected("\n".join(test_str) + "\n") 
12843*da0073e9SAndroid Build Coastguard Worker 12844*da0073e9SAndroid Build Coastguard Worker # String frontend , Python 3-style type annotations , Script method 12845*da0073e9SAndroid Build Coastguard Worker def test_annot_string_py3_method(self): 12846*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.jit.ScriptModule): 12847*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12848*da0073e9SAndroid Build Coastguard Worker super().__init__() 12849*da0073e9SAndroid Build Coastguard Worker 12850*da0073e9SAndroid Build Coastguard Worker code = ''' 12851*da0073e9SAndroid Build Coastguard Worker def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 12852*da0073e9SAndroid Build Coastguard Worker return x, x 12853*da0073e9SAndroid Build Coastguard Worker ''' 12854*da0073e9SAndroid Build Coastguard Worker test_str = [] 12855*da0073e9SAndroid Build Coastguard Worker for pair in self.type_input_return_pairs(): 12856*da0073e9SAndroid Build Coastguard Worker # clear the class registry as we will be defining foo multiple times 12857*da0073e9SAndroid Build Coastguard Worker jit_utils.clear_class_registry() 12858*da0073e9SAndroid Build Coastguard Worker tm = TestModule() 12859*da0073e9SAndroid Build Coastguard Worker tm.define(self.format_code(code, pair)) 12860*da0073e9SAndroid Build Coastguard Worker test_str.append(str(tm.foo.schema)) 12861*da0073e9SAndroid Build Coastguard Worker self.assertExpectedStripMangled("\n".join(test_str) + "\n") 12862*da0073e9SAndroid Build Coastguard Worker 12863*da0073e9SAndroid Build Coastguard Worker # String frontend , MyPy-style type comments , Script function 12864*da0073e9SAndroid Build Coastguard Worker def test_annot_string_mypy_fn(self): 12865*da0073e9SAndroid Build Coastguard Worker code = ''' 12866*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 12867*da0073e9SAndroid Build Coastguard Worker # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 
12868*da0073e9SAndroid Build Coastguard Worker return x, x 12869*da0073e9SAndroid Build Coastguard Worker ''' 12870*da0073e9SAndroid Build Coastguard Worker test_str = [] 12871*da0073e9SAndroid Build Coastguard Worker for pair in self.type_input_return_pairs(): 12872*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(self.format_code(code, pair)) 12873*da0073e9SAndroid Build Coastguard Worker test_str.append(str(cu.foo.schema)) 12874*da0073e9SAndroid Build Coastguard Worker self.assertExpectedStripMangled("\n".join(test_str) + "\n") 12875*da0073e9SAndroid Build Coastguard Worker 12876*da0073e9SAndroid Build Coastguard Worker # String frontend , MyPy-style type comments , Script method 12877*da0073e9SAndroid Build Coastguard Worker def test_annot_string_mypy_method(self): 12878*da0073e9SAndroid Build Coastguard Worker class TestModule(torch.jit.ScriptModule): 12879*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 12880*da0073e9SAndroid Build Coastguard Worker super().__init__() 12881*da0073e9SAndroid Build Coastguard Worker 12882*da0073e9SAndroid Build Coastguard Worker code = ''' 12883*da0073e9SAndroid Build Coastguard Worker def foo(self, x, y): 12884*da0073e9SAndroid Build Coastguard Worker # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 12885*da0073e9SAndroid Build Coastguard Worker return x, x 12886*da0073e9SAndroid Build Coastguard Worker ''' 12887*da0073e9SAndroid Build Coastguard Worker 12888*da0073e9SAndroid Build Coastguard Worker test_str = [] 12889*da0073e9SAndroid Build Coastguard Worker for pair in self.type_input_return_pairs(): 12890*da0073e9SAndroid Build Coastguard Worker # clear the class registry as we will be defining foo multiple times 12891*da0073e9SAndroid Build Coastguard Worker jit_utils.clear_class_registry() 12892*da0073e9SAndroid Build Coastguard Worker tm = TestModule() 12893*da0073e9SAndroid Build Coastguard Worker tm.define(self.format_code(code, pair)) 
12894*da0073e9SAndroid Build Coastguard Worker test_str.append(str(tm.foo.schema)) 12895*da0073e9SAndroid Build Coastguard Worker self.assertExpectedStripMangled("\n".join(test_str) + "\n") 12896*da0073e9SAndroid Build Coastguard Worker 12897*da0073e9SAndroid Build Coastguard Worker # Python AST Frontend , Python 3-style type annotations , Script function 12898*da0073e9SAndroid Build Coastguard Worker def test_annot_ast_py3_fn(self): 12899*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 12900*da0073e9SAndroid Build Coastguard Worker from typing import Tuple, List, Optional 12901*da0073e9SAndroid Build Coastguard Worker from torch import Tensor 12902*da0073e9SAndroid Build Coastguard Worker from torch.jit.annotations import BroadcastingList2, BroadcastingList3 12903*da0073e9SAndroid Build Coastguard Worker import torch 12904*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12905*da0073e9SAndroid Build Coastguard Worker def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 12906*da0073e9SAndroid Build Coastguard Worker return x, x 12907*da0073e9SAndroid Build Coastguard Worker ''') 12908*da0073e9SAndroid Build Coastguard Worker test_str = [] 12909*da0073e9SAndroid Build Coastguard Worker for pair in self.type_input_return_pairs(): 12910*da0073e9SAndroid Build Coastguard Worker fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') 12911*da0073e9SAndroid Build Coastguard Worker test_str.append(str(fn.schema)) 12912*da0073e9SAndroid Build Coastguard Worker self.assertExpectedStripMangled("\n".join(test_str) + "\n") 12913*da0073e9SAndroid Build Coastguard Worker 12914*da0073e9SAndroid Build Coastguard Worker def test_multiline_annot_ast_py3_fn(self): 12915*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 12916*da0073e9SAndroid Build Coastguard Worker from typing import Tuple, List, Optional 12917*da0073e9SAndroid Build Coastguard Worker from torch import Tensor 12918*da0073e9SAndroid Build Coastguard 
Worker from torch.jit.annotations import BroadcastingList2, BroadcastingList3 12919*da0073e9SAndroid Build Coastguard Worker import torch 12920*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12921*da0073e9SAndroid Build Coastguard Worker def foo(x, # type: {input} 12922*da0073e9SAndroid Build Coastguard Worker y # type: Tuple[Tensor, Tensor] 12923*da0073e9SAndroid Build Coastguard Worker ): 12924*da0073e9SAndroid Build Coastguard Worker # type: (...) -> Tuple[{output}, {output}] 12925*da0073e9SAndroid Build Coastguard Worker return x, x 12926*da0073e9SAndroid Build Coastguard Worker ''') 12927*da0073e9SAndroid Build Coastguard Worker test_str = [] 12928*da0073e9SAndroid Build Coastguard Worker 12929*da0073e9SAndroid Build Coastguard Worker for pair in self.type_input_return_pairs(): 12930*da0073e9SAndroid Build Coastguard Worker fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') 12931*da0073e9SAndroid Build Coastguard Worker args = fn.schema.arguments 12932*da0073e9SAndroid Build Coastguard Worker returns = fn.schema.returns 12933*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(args[0].type), pair[1]) 12934*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(args[1].type), "Tuple[Tensor, Tensor]") 12935*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(returns[0].type), f"Tuple[{pair[1]}, {pair[1]}]") 12936*da0073e9SAndroid Build Coastguard Worker 12937*da0073e9SAndroid Build Coastguard Worker def test_bad_multiline_annotations(self): 12938*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Return type line"): 12939*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12940*da0073e9SAndroid Build Coastguard Worker def bad_type_line(a, # type: Tensor 12941*da0073e9SAndroid Build Coastguard Worker b, # type: Tensor 12942*da0073e9SAndroid Build Coastguard Worker c # type: Tensor 12943*da0073e9SAndroid Build Coastguard Worker ): 12944*da0073e9SAndroid Build Coastguard 
Worker # type: (int, int, int) -> Tensor 12945*da0073e9SAndroid Build Coastguard Worker # type: bad type line # noqa: F723 12946*da0073e9SAndroid Build Coastguard Worker 12947*da0073e9SAndroid Build Coastguard Worker return a + b + c 12948*da0073e9SAndroid Build Coastguard Worker 12949*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Return type line"): 12950*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12951*da0073e9SAndroid Build Coastguard Worker def bad_return_line(a, # type: Tensor 12952*da0073e9SAndroid Build Coastguard Worker b, 12953*da0073e9SAndroid Build Coastguard Worker c # type: Tensor 12954*da0073e9SAndroid Build Coastguard Worker ): 12955*da0073e9SAndroid Build Coastguard Worker # type: (int, int, int) -> Tensor 12956*da0073e9SAndroid Build Coastguard Worker return a + b + c 12957*da0073e9SAndroid Build Coastguard Worker 12958*da0073e9SAndroid Build Coastguard Worker # TODO: this should be supported but is difficult to parse 12959*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Number of type annotations"): 12960*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12961*da0073e9SAndroid Build Coastguard Worker def missing_type(a, # type: Tensor 12962*da0073e9SAndroid Build Coastguard Worker b, 12963*da0073e9SAndroid Build Coastguard Worker c # type: Tensor 12964*da0073e9SAndroid Build Coastguard Worker ): 12965*da0073e9SAndroid Build Coastguard Worker # type: (...) 
-> Tensor 12966*da0073e9SAndroid Build Coastguard Worker return a + b + c 12967*da0073e9SAndroid Build Coastguard Worker 12968*da0073e9SAndroid Build Coastguard Worker # Python AST Frontend , Python 3-style type annotations , Script method 12969*da0073e9SAndroid Build Coastguard Worker def test_annot_ast_py3_method(self): 12970*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 12971*da0073e9SAndroid Build Coastguard Worker from typing import Tuple, List, Optional 12972*da0073e9SAndroid Build Coastguard Worker from torch import Tensor 12973*da0073e9SAndroid Build Coastguard Worker from torch.jit.annotations import BroadcastingList2, \\ 12974*da0073e9SAndroid Build Coastguard Worker BroadcastingList3 12975*da0073e9SAndroid Build Coastguard Worker import torch 12976*da0073e9SAndroid Build Coastguard Worker class FooModule(torch.jit.ScriptModule): 12977*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 12978*da0073e9SAndroid Build Coastguard Worker def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]: 12979*da0073e9SAndroid Build Coastguard Worker return x, x 12980*da0073e9SAndroid Build Coastguard Worker instance = FooModule() 12981*da0073e9SAndroid Build Coastguard Worker ''') 12982*da0073e9SAndroid Build Coastguard Worker 12983*da0073e9SAndroid Build Coastguard Worker test_str = [] 12984*da0073e9SAndroid Build Coastguard Worker for pair in self.type_input_return_pairs(): 12985*da0073e9SAndroid Build Coastguard Worker fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance') 12986*da0073e9SAndroid Build Coastguard Worker test_str.append(str(fn.foo.schema)) 12987*da0073e9SAndroid Build Coastguard Worker self.assertExpectedStripMangled("\n".join(test_str) + "\n") 12988*da0073e9SAndroid Build Coastguard Worker 12989*da0073e9SAndroid Build Coastguard Worker # Python AST Frontend , MyPy-style type comments , Script function 12990*da0073e9SAndroid Build Coastguard Worker def test_annot_ast_mypy_fn(self): 
12991*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 12992*da0073e9SAndroid Build Coastguard Worker import torch 12993*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 12994*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 12995*da0073e9SAndroid Build Coastguard Worker # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 12996*da0073e9SAndroid Build Coastguard Worker return x, x 12997*da0073e9SAndroid Build Coastguard Worker ''') 12998*da0073e9SAndroid Build Coastguard Worker 12999*da0073e9SAndroid Build Coastguard Worker test_str = [] 13000*da0073e9SAndroid Build Coastguard Worker for pair in self.type_input_return_pairs(): 13001*da0073e9SAndroid Build Coastguard Worker fn = jit_utils._get_py3_code(self.format_code(code, pair), 'foo') 13002*da0073e9SAndroid Build Coastguard Worker test_str.append(str(fn.schema)) 13003*da0073e9SAndroid Build Coastguard Worker self.assertExpected("\n".join(test_str) + "\n") 13004*da0073e9SAndroid Build Coastguard Worker 13005*da0073e9SAndroid Build Coastguard Worker # Python AST Frontend , MyPy-style type comments , Script method 13006*da0073e9SAndroid Build Coastguard Worker def test_annot_ast_mypy_method(self): 13007*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 13008*da0073e9SAndroid Build Coastguard Worker import torch 13009*da0073e9SAndroid Build Coastguard Worker class FooModule(torch.jit.ScriptModule): 13010*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 13011*da0073e9SAndroid Build Coastguard Worker def foo(self, x, y): 13012*da0073e9SAndroid Build Coastguard Worker # type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}] 13013*da0073e9SAndroid Build Coastguard Worker return x, x 13014*da0073e9SAndroid Build Coastguard Worker instance = FooModule() 13015*da0073e9SAndroid Build Coastguard Worker ''') 13016*da0073e9SAndroid Build Coastguard Worker 13017*da0073e9SAndroid Build Coastguard Worker test_str = [] 13018*da0073e9SAndroid Build 
Coastguard Worker for pair in self.type_input_return_pairs(): 13019*da0073e9SAndroid Build Coastguard Worker fn = jit_utils._get_py3_code(self.format_code(code, pair), 'instance') 13020*da0073e9SAndroid Build Coastguard Worker test_str.append(str(fn.foo.schema)) 13021*da0073e9SAndroid Build Coastguard Worker self.assertExpectedStripMangled("\n".join(test_str) + "\n") 13022*da0073e9SAndroid Build Coastguard Worker 13023*da0073e9SAndroid Build Coastguard Worker # Tests that "# type: ignore[*]" is supported in type lines and is 13024*da0073e9SAndroid Build Coastguard Worker # properly ignored. 13025*da0073e9SAndroid Build Coastguard Worker def test_mypy_type_ignore(self): 13026*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13027*da0073e9SAndroid Build Coastguard Worker def foo(x): # type: ignore 13028*da0073e9SAndroid Build Coastguard Worker return x 13029*da0073e9SAndroid Build Coastguard Worker 13030*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13031*da0073e9SAndroid Build Coastguard Worker def bar(x): # type: ignore[no-redef] 13032*da0073e9SAndroid Build Coastguard Worker return x 13033*da0073e9SAndroid Build Coastguard Worker 13034*da0073e9SAndroid Build Coastguard Worker def test_method_casts_script(self): 13035*da0073e9SAndroid Build Coastguard Worker cast_types = [ 13036*da0073e9SAndroid Build Coastguard Worker 'byte', 'char', 'double', 'float', 'int', 'long', 'short' 13037*da0073e9SAndroid Build Coastguard Worker ] 13038*da0073e9SAndroid Build Coastguard Worker 13039*da0073e9SAndroid Build Coastguard Worker for cast_type in cast_types: 13040*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(f''' 13041*da0073e9SAndroid Build Coastguard Worker def cast_to(x): 13042*da0073e9SAndroid Build Coastguard Worker return x.{cast_type}() 13043*da0073e9SAndroid Build Coastguard Worker ''') 13044*da0073e9SAndroid Build Coastguard Worker 13045*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4, 5) * 128 
13046*da0073e9SAndroid Build Coastguard Worker cu_result = cu.cast_to(x) 13047*da0073e9SAndroid Build Coastguard Worker reference = getattr(x, cast_type)() 13048*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cu_result, reference) 13049*da0073e9SAndroid Build Coastguard Worker 13050*da0073e9SAndroid Build Coastguard Worker def test_string_frontend_elif(self): 13051*da0073e9SAndroid Build Coastguard Worker code = ''' 13052*da0073e9SAndroid Build Coastguard Worker def func(niter): 13053*da0073e9SAndroid Build Coastguard Worker # type: (int) 13054*da0073e9SAndroid Build Coastguard Worker rv = 0 13055*da0073e9SAndroid Build Coastguard Worker for i in range(niter): 13056*da0073e9SAndroid Build Coastguard Worker if i % 3 == 0 and i % 5 == 0: 13057*da0073e9SAndroid Build Coastguard Worker rv += 35 13058*da0073e9SAndroid Build Coastguard Worker elif i % 3 == 0: 13059*da0073e9SAndroid Build Coastguard Worker rv += 3 13060*da0073e9SAndroid Build Coastguard Worker elif i % 5 == 0: 13061*da0073e9SAndroid Build Coastguard Worker rv += 5 13062*da0073e9SAndroid Build Coastguard Worker else: 13063*da0073e9SAndroid Build Coastguard Worker rv += i 13064*da0073e9SAndroid Build Coastguard Worker return rv 13065*da0073e9SAndroid Build Coastguard Worker ''' 13066*da0073e9SAndroid Build Coastguard Worker 13067*da0073e9SAndroid Build Coastguard Worker self.checkScript(dedent(code), (101,)) 13068*da0073e9SAndroid Build Coastguard Worker 13069*da0073e9SAndroid Build Coastguard Worker def test_module_parameters_and_buffers(self): 13070*da0073e9SAndroid Build Coastguard Worker weights = torch.randn(10, 10) 13071*da0073e9SAndroid Build Coastguard Worker bias = torch.randn(10) 13072*da0073e9SAndroid Build Coastguard Worker weights2 = torch.randn(10, 10) 13073*da0073e9SAndroid Build Coastguard Worker bias2 = torch.randn(10) 13074*da0073e9SAndroid Build Coastguard Worker 13075*da0073e9SAndroid Build Coastguard Worker class TestLinear(torch.nn.Module): 13076*da0073e9SAndroid Build 
Coastguard Worker def __init__(self, in_features, out_features): 13077*da0073e9SAndroid Build Coastguard Worker super().__init__() 13078*da0073e9SAndroid Build Coastguard Worker self.in_features = in_features 13079*da0073e9SAndroid Build Coastguard Worker self.out_features = out_features 13080*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.empty(out_features, in_features)) 13081*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter(torch.empty(out_features)) 13082*da0073e9SAndroid Build Coastguard Worker self.counter = nn.Buffer(torch.ones(out_features)) 13083*da0073e9SAndroid Build Coastguard Worker self.reset_parameters() 13084*da0073e9SAndroid Build Coastguard Worker 13085*da0073e9SAndroid Build Coastguard Worker def reset_parameters(self): 13086*da0073e9SAndroid Build Coastguard Worker torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 13087*da0073e9SAndroid Build Coastguard Worker if self.bias is not None: 13088*da0073e9SAndroid Build Coastguard Worker fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) 13089*da0073e9SAndroid Build Coastguard Worker bound = 1 / math.sqrt(fan_in) 13090*da0073e9SAndroid Build Coastguard Worker torch.nn.init.uniform_(self.bias, -bound, bound) 13091*da0073e9SAndroid Build Coastguard Worker 13092*da0073e9SAndroid Build Coastguard Worker def forward(self, input): 13093*da0073e9SAndroid Build Coastguard Worker return F.linear(input, self.weight, self.bias) + self.counter 13094*da0073e9SAndroid Build Coastguard Worker 13095*da0073e9SAndroid Build Coastguard Worker # Initialize a ScriptModule that uses the weak module above multiple times 13096*da0073e9SAndroid Build Coastguard Worker class Strong(torch.jit.ScriptModule): 13097*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 13098*da0073e9SAndroid Build Coastguard Worker super().__init__() 13099*da0073e9SAndroid Build Coastguard Worker self.fc1 = TestLinear(10, 10) 13100*da0073e9SAndroid 
Build Coastguard Worker self.fc1.weight = torch.nn.Parameter(weights) 13101*da0073e9SAndroid Build Coastguard Worker self.fc1.bias = torch.nn.Parameter(bias) 13102*da0073e9SAndroid Build Coastguard Worker self.fc2 = TestLinear(10, 10) 13103*da0073e9SAndroid Build Coastguard Worker self.fc2.weight = torch.nn.Parameter(weights2) 13104*da0073e9SAndroid Build Coastguard Worker self.fc2.bias = torch.nn.Parameter(bias2) 13105*da0073e9SAndroid Build Coastguard Worker 13106*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 13107*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 13108*da0073e9SAndroid Build Coastguard Worker return x + self.fc1(x) + self.fc1(x) + self.fc2(x) 13109*da0073e9SAndroid Build Coastguard Worker 13110*da0073e9SAndroid Build Coastguard Worker strong_mod = Strong() 13111*da0073e9SAndroid Build Coastguard Worker 13112*da0073e9SAndroid Build Coastguard Worker # Run same calculation as module 13113*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(10) 13114*da0073e9SAndroid Build Coastguard Worker lin = torch.nn.Linear(10, 10) 13115*da0073e9SAndroid Build Coastguard Worker lin.weight = torch.nn.Parameter(weights) 13116*da0073e9SAndroid Build Coastguard Worker lin.bias = torch.nn.Parameter(bias) 13117*da0073e9SAndroid Build Coastguard Worker lin2 = torch.nn.Linear(10, 10) 13118*da0073e9SAndroid Build Coastguard Worker lin2.weight = torch.nn.Parameter(weights2) 13119*da0073e9SAndroid Build Coastguard Worker lin2.bias = torch.nn.Parameter(bias2) 13120*da0073e9SAndroid Build Coastguard Worker expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10) 13121*da0073e9SAndroid Build Coastguard Worker 13122*da0073e9SAndroid Build Coastguard Worker self.assertEqual(strong_mod(inp), expected_result) 13123*da0073e9SAndroid Build Coastguard Worker self.assertExportImportModule(strong_mod, (inp,)) 13124*da0073e9SAndroid Build Coastguard Worker 13125*da0073e9SAndroid Build Coastguard Worker def 
test_module_copying(self): 13126*da0073e9SAndroid Build Coastguard Worker class Submodule(torch.nn.Module): 13127*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 13128*da0073e9SAndroid Build Coastguard Worker return x + 100 13129*da0073e9SAndroid Build Coastguard Worker 13130*da0073e9SAndroid Build Coastguard Worker class Weak(torch.nn.Module): 13131*da0073e9SAndroid Build Coastguard Worker def __init__(self, in_features, out_features): 13132*da0073e9SAndroid Build Coastguard Worker super().__init__() 13133*da0073e9SAndroid Build Coastguard Worker self.weight = torch.nn.Parameter(torch.ones(out_features, in_features)) 13134*da0073e9SAndroid Build Coastguard Worker self.bias = torch.nn.Parameter(torch.ones(out_features)) 13135*da0073e9SAndroid Build Coastguard Worker self.buffer = nn.Buffer(torch.ones(out_features)) 13136*da0073e9SAndroid Build Coastguard Worker self.submodule = Submodule() 13137*da0073e9SAndroid Build Coastguard Worker 13138*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 13139*da0073e9SAndroid Build Coastguard Worker return F.linear(x, self.weight, self.bias) \ 13140*da0073e9SAndroid Build Coastguard Worker + self.buffer + self.submodule(x) 13141*da0073e9SAndroid Build Coastguard Worker 13142*da0073e9SAndroid Build Coastguard Worker class Strong(torch.jit.ScriptModule): 13143*da0073e9SAndroid Build Coastguard Worker def __init__(self, weak): 13144*da0073e9SAndroid Build Coastguard Worker super().__init__() 13145*da0073e9SAndroid Build Coastguard Worker self.weak = weak 13146*da0073e9SAndroid Build Coastguard Worker 13147*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 13148*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 13149*da0073e9SAndroid Build Coastguard Worker return self.weak(x) 13150*da0073e9SAndroid Build Coastguard Worker 13151*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(5, 5) * 5 13152*da0073e9SAndroid Build Coastguard Worker weak_mod = Weak(5, 5) 
13153*da0073e9SAndroid Build Coastguard Worker strong_mod = Strong(weak_mod) 13154*da0073e9SAndroid Build Coastguard Worker 13155*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule)) 13156*da0073e9SAndroid Build Coastguard Worker self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule)) 13157*da0073e9SAndroid Build Coastguard Worker 13158*da0073e9SAndroid Build Coastguard Worker self.assertIs(strong_mod.weak.weight, weak_mod.weight) 13159*da0073e9SAndroid Build Coastguard Worker self.assertIs(strong_mod.weak.buffer, weak_mod.buffer) 13160*da0073e9SAndroid Build Coastguard Worker # strong_mod.weak.submodule has been recursively scripted 13161*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(strong_mod.weak.submodule, weak_mod.submodule) 13162*da0073e9SAndroid Build Coastguard Worker 13163*da0073e9SAndroid Build Coastguard Worker weak_mod.weight.data += torch.ones(5, 5) * 100 13164*da0073e9SAndroid Build Coastguard Worker self.assertTrue(strong_mod(inp).allclose(weak_mod(inp))) 13165*da0073e9SAndroid Build Coastguard Worker 13166*da0073e9SAndroid Build Coastguard Worker # Re-assignment is not tracked 13167*da0073e9SAndroid Build Coastguard Worker weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100) 13168*da0073e9SAndroid Build Coastguard Worker self.assertFalse(strong_mod(inp).allclose(weak_mod(inp))) 13169*da0073e9SAndroid Build Coastguard Worker 13170*da0073e9SAndroid Build Coastguard Worker def test_backend_cudnn_enabled(self): 13171*da0073e9SAndroid Build Coastguard Worker # Only test that this compiles 13172*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13173*da0073e9SAndroid Build Coastguard Worker def fn(x): 13174*da0073e9SAndroid Build Coastguard Worker if torch.backends.cudnn.enabled: 13175*da0073e9SAndroid Build Coastguard Worker x = x + 2 13176*da0073e9SAndroid Build Coastguard Worker else: 13177*da0073e9SAndroid Build Coastguard Worker x = x + 3 
13178*da0073e9SAndroid Build Coastguard Worker return x 13179*da0073e9SAndroid Build Coastguard Worker 13180*da0073e9SAndroid Build Coastguard Worker def test_inplace_add(self): 13181*da0073e9SAndroid Build Coastguard Worker 13182*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 13183*da0073e9SAndroid Build Coastguard Worker c = a + b 13184*da0073e9SAndroid Build Coastguard Worker c.add_(b) 13185*da0073e9SAndroid Build Coastguard Worker return c 13186*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(3), torch.rand(3))) 13187*da0073e9SAndroid Build Coastguard Worker 13188*da0073e9SAndroid Build Coastguard Worker def test_add_out(self): 13189*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 13190*da0073e9SAndroid Build Coastguard Worker c = a + b 13191*da0073e9SAndroid Build Coastguard Worker e = 2 * a 13192*da0073e9SAndroid Build Coastguard Worker torch.add(c, b, out=e) 13193*da0073e9SAndroid Build Coastguard Worker return e 13194*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(3), torch.rand(3))) 13195*da0073e9SAndroid Build Coastguard Worker 13196*da0073e9SAndroid Build Coastguard Worker def test_tuple_error_msg(self): 13197*da0073e9SAndroid Build Coastguard Worker def fn(t: Any): 13198*da0073e9SAndroid Build Coastguard Worker if isinstance(t, tuple): 13199*da0073e9SAndroid Build Coastguard Worker a, b = t 13200*da0073e9SAndroid Build Coastguard Worker return a + b 13201*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegexWithHighlight(RuntimeError, "Provided tuple is not fully defined/refined", "t"): 13202*da0073e9SAndroid Build Coastguard Worker s = torch.jit.script(fn) 13203*da0073e9SAndroid Build Coastguard Worker 13204*da0073e9SAndroid Build Coastguard Worker def test_augmented_assign(self): 13205*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 13206*da0073e9SAndroid Build Coastguard Worker a += b 13207*da0073e9SAndroid Build Coastguard Worker a -= b 
13208*da0073e9SAndroid Build Coastguard Worker a /= b 13209*da0073e9SAndroid Build Coastguard Worker a *= b 13210*da0073e9SAndroid Build Coastguard Worker return a, b 13211*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(3), torch.rand(3))) 13212*da0073e9SAndroid Build Coastguard Worker 13213*da0073e9SAndroid Build Coastguard Worker def test_ignored_props(self): 13214*da0073e9SAndroid Build Coastguard Worker class A(nn.Module): 13215*da0073e9SAndroid Build Coastguard Worker __jit_ignored_attributes__ = ["ignored", "ignored_return_val"] 13216*da0073e9SAndroid Build Coastguard Worker 13217*da0073e9SAndroid Build Coastguard Worker @property 13218*da0073e9SAndroid Build Coastguard Worker def ignored(self): 13219*da0073e9SAndroid Build Coastguard Worker raise ValueError("shouldn't be called") 13220*da0073e9SAndroid Build Coastguard Worker 13221*da0073e9SAndroid Build Coastguard Worker @property 13222*da0073e9SAndroid Build Coastguard Worker def ignored_return_val(self): 13223*da0073e9SAndroid Build Coastguard Worker return 1 13224*da0073e9SAndroid Build Coastguard Worker 13225*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 13226*da0073e9SAndroid Build Coastguard Worker def call(self): 13227*da0073e9SAndroid Build Coastguard Worker return self.ignored_return_val 13228*da0073e9SAndroid Build Coastguard Worker 13229*da0073e9SAndroid Build Coastguard Worker f = torch.jit.script(A()) 13230*da0073e9SAndroid Build Coastguard Worker # jank way to test if there is no error 13231*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(f, torch.jit.ScriptModule)) 13232*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(f.call(), property)) 13233*da0073e9SAndroid Build Coastguard Worker 13234*da0073e9SAndroid Build Coastguard Worker 13235*da0073e9SAndroid Build Coastguard Worker def test_pass(self): 13236*da0073e9SAndroid Build Coastguard Worker def foo(x): 13237*da0073e9SAndroid Build Coastguard Worker # type: 
(bool) -> int 13238*da0073e9SAndroid Build Coastguard Worker for _i in range(3): 13239*da0073e9SAndroid Build Coastguard Worker pass 13240*da0073e9SAndroid Build Coastguard Worker if x: 13241*da0073e9SAndroid Build Coastguard Worker pass 13242*da0073e9SAndroid Build Coastguard Worker else: 13243*da0073e9SAndroid Build Coastguard Worker pass 13244*da0073e9SAndroid Build Coastguard Worker return 3 13245*da0073e9SAndroid Build Coastguard Worker 13246*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (True,)) 13247*da0073e9SAndroid Build Coastguard Worker 13248*da0073e9SAndroid Build Coastguard Worker def test_lhs_indexing(self): 13249*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 13250*da0073e9SAndroid Build Coastguard Worker a = a.clone() 13251*da0073e9SAndroid Build Coastguard Worker a[0] = b 13252*da0073e9SAndroid Build Coastguard Worker return a 13253*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) 13254*da0073e9SAndroid Build Coastguard Worker 13255*da0073e9SAndroid Build Coastguard Worker def test_lhs_advanced_indexing_assignment(self): 13256*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 13257*da0073e9SAndroid Build Coastguard Worker a = torch.exp(x) 13258*da0073e9SAndroid Build Coastguard Worker b = x == 1 13259*da0073e9SAndroid Build Coastguard Worker a[b] = y[b] 13260*da0073e9SAndroid Build Coastguard Worker return a 13261*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3))) 13262*da0073e9SAndroid Build Coastguard Worker 13263*da0073e9SAndroid Build Coastguard Worker def test_lhs_advanced_indexing_augmented_assignment(self): 13264*da0073e9SAndroid Build Coastguard Worker def foo(x, y): 13265*da0073e9SAndroid Build Coastguard Worker a = torch.exp(x) 13266*da0073e9SAndroid Build Coastguard Worker b = x == 1 13267*da0073e9SAndroid Build Coastguard Worker a[b] += y[b] 13268*da0073e9SAndroid Build Coastguard Worker return a 
13269*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3))) 13270*da0073e9SAndroid Build Coastguard Worker 13271*da0073e9SAndroid Build Coastguard Worker def test_lhs_indexing_list(self): 13272*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 13273*da0073e9SAndroid Build Coastguard Worker ls = [a] 13274*da0073e9SAndroid Build Coastguard Worker ls[0] = b 13275*da0073e9SAndroid Build Coastguard Worker return ls 13276*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) 13277*da0073e9SAndroid Build Coastguard Worker 13278*da0073e9SAndroid Build Coastguard Worker def test_inplace_copy_script(self): 13279*da0073e9SAndroid Build Coastguard Worker def foo(x): 13280*da0073e9SAndroid Build Coastguard Worker a = torch.rand(3, 4) 13281*da0073e9SAndroid Build Coastguard Worker a.copy_(x) 13282*da0073e9SAndroid Build Coastguard Worker return a 13283*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(3, 4),)) 13284*da0073e9SAndroid Build Coastguard Worker 13285*da0073e9SAndroid Build Coastguard Worker def test_lhs_indexing_increment(self): 13286*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 13287*da0073e9SAndroid Build Coastguard Worker a[0] += b 13288*da0073e9SAndroid Build Coastguard Worker return a 13289*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) 13290*da0073e9SAndroid Build Coastguard Worker 13291*da0073e9SAndroid Build Coastguard Worker def test_lhs_indexing_increment_list(self): 13292*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 13293*da0073e9SAndroid Build Coastguard Worker a = a.clone() 13294*da0073e9SAndroid Build Coastguard Worker ls = [a, b] 13295*da0073e9SAndroid Build Coastguard Worker ls[0] += b 13296*da0073e9SAndroid Build Coastguard Worker return ls 13297*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) 
13298*da0073e9SAndroid Build Coastguard Worker 13299*da0073e9SAndroid Build Coastguard Worker def test_lhs_indexing_increment_list_prim(self): 13300*da0073e9SAndroid Build Coastguard Worker def foo(): 13301*da0073e9SAndroid Build Coastguard Worker ls = [1, 2, 3] 13302*da0073e9SAndroid Build Coastguard Worker ls[0] += 5 13303*da0073e9SAndroid Build Coastguard Worker return ls 13304*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, ()) 13305*da0073e9SAndroid Build Coastguard Worker 13306*da0073e9SAndroid Build Coastguard Worker def test_lhs_indexing_multi(self): 13307*da0073e9SAndroid Build Coastguard Worker def foo(a, b): 13308*da0073e9SAndroid Build Coastguard Worker a = a.clone() 13309*da0073e9SAndroid Build Coastguard Worker foo, a[0], bar = (1, b, 3) 13310*da0073e9SAndroid Build Coastguard Worker return foo, a, bar 13311*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, (torch.rand(2, 3), torch.rand(3))) 13312*da0073e9SAndroid Build Coastguard Worker 13313*da0073e9SAndroid Build Coastguard Worker def test_bool_dispatch(self): 13314*da0073e9SAndroid Build Coastguard Worker with torch._jit_internal._disable_emit_hooks(): # TODO: Python print broadcasting list 13315*da0073e9SAndroid Build Coastguard Worker def kwarg_false(x): 13316*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 13317*da0073e9SAndroid Build Coastguard Worker return F.max_pool1d(x, 1, 1, return_indices=False) 13318*da0073e9SAndroid Build Coastguard Worker self.checkScript(kwarg_false, (torch.randn(3, 3, 3),)) 13319*da0073e9SAndroid Build Coastguard Worker 13320*da0073e9SAndroid Build Coastguard Worker def kwarg_true(x): 13321*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tuple[Tensor, Tensor] 13322*da0073e9SAndroid Build Coastguard Worker return F.max_pool1d(x, 1, 1, return_indices=True) 13323*da0073e9SAndroid Build Coastguard Worker self.checkScript(kwarg_true, (torch.randn(3, 3, 3),)) 13324*da0073e9SAndroid Build Coastguard Worker 
13325*da0073e9SAndroid Build Coastguard Worker def full_kwarg_false(x): 13326*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 13327*da0073e9SAndroid Build Coastguard Worker return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False) 13328*da0073e9SAndroid Build Coastguard Worker self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),)) 13329*da0073e9SAndroid Build Coastguard Worker 13330*da0073e9SAndroid Build Coastguard Worker def full_kwarg_true(x): 13331*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tuple[Tensor, Tensor] 13332*da0073e9SAndroid Build Coastguard Worker return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True) 13333*da0073e9SAndroid Build Coastguard Worker self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),)) 13334*da0073e9SAndroid Build Coastguard Worker 13335*da0073e9SAndroid Build Coastguard Worker def use_default(x): 13336*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 13337*da0073e9SAndroid Build Coastguard Worker return F.max_pool1d(x, 1, 1) 13338*da0073e9SAndroid Build Coastguard Worker self.checkScript(use_default, (torch.randn(3, 3, 3),)) 13339*da0073e9SAndroid Build Coastguard Worker 13340*da0073e9SAndroid Build Coastguard Worker def arg_false(x): 13341*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 13342*da0073e9SAndroid Build Coastguard Worker return F.max_pool1d(x, 1, 1, 0, 1, False, False) 13343*da0073e9SAndroid Build Coastguard Worker self.checkScript(arg_false, (torch.randn(3, 3, 3),)) 13344*da0073e9SAndroid Build Coastguard Worker 13345*da0073e9SAndroid Build Coastguard Worker def arg_true(x): 13346*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tuple[Tensor, Tensor] 13347*da0073e9SAndroid Build Coastguard Worker return F.max_pool1d(x, 1, 1, 0, 1, False, True) 13348*da0073e9SAndroid Build Coastguard Worker self.checkScript(arg_true, (torch.randn(3, 3, 3),)) 13349*da0073e9SAndroid Build Coastguard Worker 
13350*da0073e9SAndroid Build Coastguard Worker def test_infer_size(self): 13351*da0073e9SAndroid Build Coastguard Worker from torch._C import _infer_size 13352*da0073e9SAndroid Build Coastguard Worker 13353*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 13354*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> List[int] 13355*da0073e9SAndroid Build Coastguard Worker return _infer_size(x.size(), y.size()) 13356*da0073e9SAndroid Build Coastguard Worker 13357*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2))) 13358*da0073e9SAndroid Build Coastguard Worker 13359*da0073e9SAndroid Build Coastguard Worker def test_hash(self): 13360*da0073e9SAndroid Build Coastguard Worker def tester(fn, inputs): 13361*da0073e9SAndroid Build Coastguard Worker for x in inputs: 13362*da0073e9SAndroid Build Coastguard Worker for y in inputs: 13363*da0073e9SAndroid Build Coastguard Worker if x == y: 13364*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(x), fn(y)) 13365*da0073e9SAndroid Build Coastguard Worker else: 13366*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(fn(x), fn(y)) 13367*da0073e9SAndroid Build Coastguard Worker 13368*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13369*da0073e9SAndroid Build Coastguard Worker def int_hash(x): 13370*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 13371*da0073e9SAndroid Build Coastguard Worker return hash(x) 13372*da0073e9SAndroid Build Coastguard Worker 13373*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13374*da0073e9SAndroid Build Coastguard Worker def float_hash(x): 13375*da0073e9SAndroid Build Coastguard Worker # type: (float) -> int 13376*da0073e9SAndroid Build Coastguard Worker return hash(x) 13377*da0073e9SAndroid Build Coastguard Worker 13378*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13379*da0073e9SAndroid Build Coastguard Worker def str_hash(x): 13380*da0073e9SAndroid Build 
Coastguard Worker # type: (str) -> int 13381*da0073e9SAndroid Build Coastguard Worker return hash(x) 13382*da0073e9SAndroid Build Coastguard Worker 13383*da0073e9SAndroid Build Coastguard Worker tester(int_hash, (20, 21, 22)) 13384*da0073e9SAndroid Build Coastguard Worker tester(float_hash, (20.0, 21.00001, 22.443)) 13385*da0073e9SAndroid Build Coastguard Worker tester(str_hash, ("", "hello", "a")) 13386*da0073e9SAndroid Build Coastguard Worker 13387*da0073e9SAndroid Build Coastguard Worker def test_id(self): 13388*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected a value"): 13389*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13390*da0073e9SAndroid Build Coastguard Worker def test_id_scalars(): 13391*da0073e9SAndroid Build Coastguard Worker return id(2) == id(None) 13392*da0073e9SAndroid Build Coastguard Worker 13393*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13394*da0073e9SAndroid Build Coastguard Worker class FooTest: 13395*da0073e9SAndroid Build Coastguard Worker def __init__(self, x): 13396*da0073e9SAndroid Build Coastguard Worker self.foo = x 13397*da0073e9SAndroid Build Coastguard Worker 13398*da0073e9SAndroid Build Coastguard Worker def getFooTest(self): 13399*da0073e9SAndroid Build Coastguard Worker return self.foo 13400*da0073e9SAndroid Build Coastguard Worker 13401*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13402*da0073e9SAndroid Build Coastguard Worker def test_id_class_types(): 13403*da0073e9SAndroid Build Coastguard Worker obj1 = FooTest(torch.tensor(3)) 13404*da0073e9SAndroid Build Coastguard Worker obj2 = FooTest(torch.tensor(2)) 13405*da0073e9SAndroid Build Coastguard Worker assert obj1 is not obj2 13406*da0073e9SAndroid Build Coastguard Worker assert id(obj1) != id(obj2) 13407*da0073e9SAndroid Build Coastguard Worker assert id(obj1) != id(None) 13408*da0073e9SAndroid Build Coastguard Worker return True 13409*da0073e9SAndroid Build Coastguard Worker 
13410*da0073e9SAndroid Build Coastguard Worker self.assertTrue(test_id_class_types()) 13411*da0073e9SAndroid Build Coastguard Worker 13412*da0073e9SAndroid Build Coastguard Worker def test_mutable_dce(self): 13413*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13414*da0073e9SAndroid Build Coastguard Worker def foo(): 13415*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 3) 13416*da0073e9SAndroid Build Coastguard Worker a += torch.rand(2, 3) 13417*da0073e9SAndroid Build Coastguard Worker b = torch.rand(2, 3) 13418*da0073e9SAndroid Build Coastguard Worker b += torch.rand(2, 3) 13419*da0073e9SAndroid Build Coastguard Worker # b should be cleaned up but not a 13420*da0073e9SAndroid Build Coastguard Worker return a 13421*da0073e9SAndroid Build Coastguard Worker 13422*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::rand", 2, exactly=True) \ 13423*da0073e9SAndroid Build Coastguard Worker .check_count("aten::add", 1, exactly=True).run(str(foo.graph)) 13424*da0073e9SAndroid Build Coastguard Worker 13425*da0073e9SAndroid Build Coastguard Worker def test_mutable_dce_block(self): 13426*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13427*da0073e9SAndroid Build Coastguard Worker def foo(): 13428*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 3) 13429*da0073e9SAndroid Build Coastguard Worker a += torch.rand(2, 3) 13430*da0073e9SAndroid Build Coastguard Worker b = torch.rand(2, 3) 13431*da0073e9SAndroid Build Coastguard Worker if bool(a > torch.zeros(2, 3)): 13432*da0073e9SAndroid Build Coastguard Worker b += torch.rand(2, 3) 13433*da0073e9SAndroid Build Coastguard Worker a += torch.rand(2, 3) 13434*da0073e9SAndroid Build Coastguard Worker # a should be cleaned up but not b 13435*da0073e9SAndroid Build Coastguard Worker return b 13436*da0073e9SAndroid Build Coastguard Worker 13437*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::If").check_count("aten::rand", 1, exactly=True) \ 
13438*da0073e9SAndroid Build Coastguard Worker .run(str(foo.graph)) 13439*da0073e9SAndroid Build Coastguard Worker 13440*da0073e9SAndroid Build Coastguard Worker def test_mutable_dce_graph_input(self): 13441*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13442*da0073e9SAndroid Build Coastguard Worker def foo(a): 13443*da0073e9SAndroid Build Coastguard Worker a += torch.rand(2, 3) 13444*da0073e9SAndroid Build Coastguard Worker # shouldn't clean up `a` even though it's not used in the output 13445*da0073e9SAndroid Build Coastguard Worker 13446*da0073e9SAndroid Build Coastguard Worker FileCheck().check("aten::rand").check("aten::add").run(str(foo.graph)) 13447*da0073e9SAndroid Build Coastguard Worker 13448*da0073e9SAndroid Build Coastguard Worker def test_mutable_dce_list(self): 13449*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13450*da0073e9SAndroid Build Coastguard Worker def foo(a): 13451*da0073e9SAndroid Build Coastguard Worker l = [] 13452*da0073e9SAndroid Build Coastguard Worker l.append(a) 13453*da0073e9SAndroid Build Coastguard Worker c = l[0] 13454*da0073e9SAndroid Build Coastguard Worker b = torch.rand(2, 3) 13455*da0073e9SAndroid Build Coastguard Worker c += torch.rand(2, 3) 13456*da0073e9SAndroid Build Coastguard Worker return b 13457*da0073e9SAndroid Build Coastguard Worker 13458*da0073e9SAndroid Build Coastguard Worker # c does not get cleaned up because there is a wildcard + mutation 13459*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("aten::rand", 2, exactly=True).run(str(foo.graph)) 13460*da0073e9SAndroid Build Coastguard Worker 13461*da0073e9SAndroid Build Coastguard Worker def test_mutable_dce_loop(self): 13462*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13463*da0073e9SAndroid Build Coastguard Worker def foo(a): 13464*da0073e9SAndroid Build Coastguard Worker l = [] 13465*da0073e9SAndroid Build Coastguard Worker l.append(a) 13466*da0073e9SAndroid Build Coastguard Worker i = 0 
13467*da0073e9SAndroid Build Coastguard Worker b = torch.rand(2, 3) 13468*da0073e9SAndroid Build Coastguard Worker while i < 1: 13469*da0073e9SAndroid Build Coastguard Worker dead = torch.rand(2, 3) 13470*da0073e9SAndroid Build Coastguard Worker c = l[0] 13471*da0073e9SAndroid Build Coastguard Worker c += torch.rand(2, 3) 13472*da0073e9SAndroid Build Coastguard Worker i += 1 13473*da0073e9SAndroid Build Coastguard Worker return b 13474*da0073e9SAndroid Build Coastguard Worker 13475*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::__getitem__") \ 13476*da0073e9SAndroid Build Coastguard Worker .check_count("aten::rand", 1, exactly=True).run(str(foo.graph)) 13477*da0073e9SAndroid Build Coastguard Worker 13478*da0073e9SAndroid Build Coastguard Worker def test_mutable_dce_indirect_wildcards(self): 13479*da0073e9SAndroid Build Coastguard Worker def fn(): 13480*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 3) 13481*da0073e9SAndroid Build Coastguard Worker x_1 = x.view(-1) 13482*da0073e9SAndroid Build Coastguard Worker l = [] 13483*da0073e9SAndroid Build Coastguard Worker l.append(x_1) 13484*da0073e9SAndroid Build Coastguard Worker x_view = l[0] 13485*da0073e9SAndroid Build Coastguard Worker x.add_(torch.ones(2, 3)) 13486*da0073e9SAndroid Build Coastguard Worker return x_view 13487*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 13488*da0073e9SAndroid Build Coastguard Worker 13489*da0073e9SAndroid Build Coastguard Worker def test_mutable_dce_indirect_wildcard_write(self): 13490*da0073e9SAndroid Build Coastguard Worker def fn(): 13491*da0073e9SAndroid Build Coastguard Worker indexes = torch.jit.annotate(List[Tensor], []) 13492*da0073e9SAndroid Build Coastguard Worker word_ids = torch.zeros(10, dtype=torch.int32) 13493*da0073e9SAndroid Build Coastguard Worker word_ids[1] = 1 13494*da0073e9SAndroid Build Coastguard Worker indexes.append(word_ids) 13495*da0073e9SAndroid Build Coastguard 
Worker 13496*da0073e9SAndroid Build Coastguard Worker return word_ids 13497*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ()) 13498*da0073e9SAndroid Build Coastguard Worker 13499*da0073e9SAndroid Build Coastguard Worker def test_mutable_dce_wildcards(self): 13500*da0073e9SAndroid Build Coastguard Worker def fn(): 13501*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 3) 13502*da0073e9SAndroid Build Coastguard Worker l = [] 13503*da0073e9SAndroid Build Coastguard Worker l.append(x) 13504*da0073e9SAndroid Build Coastguard Worker x_view = l[0] 13505*da0073e9SAndroid Build Coastguard Worker x.add_(torch.ones(2, 3)) 13506*da0073e9SAndroid Build Coastguard Worker return x_view 13507*da0073e9SAndroid Build Coastguard Worker 13508*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (), profiling=ProfilingMode.SIMPLE) 13509*da0073e9SAndroid Build Coastguard Worker 13510*da0073e9SAndroid Build Coastguard Worker def test_cpp_function_tensor_str(self): 13511*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 13512*da0073e9SAndroid Build Coastguard Worker scale = torch.randn(2, 2, requires_grad=True) 13513*da0073e9SAndroid Build Coastguard Worker shift = torch.randn(2, 2, requires_grad=True) 13514*da0073e9SAndroid Build Coastguard Worker 13515*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13516*da0073e9SAndroid Build Coastguard Worker def fn(x, scale, shift): 13517*da0073e9SAndroid Build Coastguard Worker return scale * x + shift 13518*da0073e9SAndroid Build Coastguard Worker 13519*da0073e9SAndroid Build Coastguard Worker with self.capture_stdout() as captured: 13520*da0073e9SAndroid Build Coastguard Worker print(fn(x, scale, shift)) 13521*da0073e9SAndroid Build Coastguard Worker 13522*da0073e9SAndroid Build Coastguard Worker def test_string_index(self): 13523*da0073e9SAndroid Build Coastguard Worker def fn(x): 13524*da0073e9SAndroid Build Coastguard Worker # type: (str) 13525*da0073e9SAndroid Build Coastguard Worker 
return x[2], x[-1] 13526*da0073e9SAndroid Build Coastguard Worker 13527*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("abcde",)) 13528*da0073e9SAndroid Build Coastguard Worker 13529*da0073e9SAndroid Build Coastguard Worker def test_ord(self): 13530*da0073e9SAndroid Build Coastguard Worker def fn(x): 13531*da0073e9SAndroid Build Coastguard Worker # type: (str) -> int 13532*da0073e9SAndroid Build Coastguard Worker return ord(x) 13533*da0073e9SAndroid Build Coastguard Worker 13534*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("h")) 13535*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("y")) 13536*da0073e9SAndroid Build Coastguard Worker 13537*da0073e9SAndroid Build Coastguard Worker def index_str_to_tensor(s): 13538*da0073e9SAndroid Build Coastguard Worker # type: (str) -> Tensor 13539*da0073e9SAndroid Build Coastguard Worker return torch.tensor(ord(s)) # noqa: T484 13540*da0073e9SAndroid Build Coastguard Worker 13541*da0073e9SAndroid Build Coastguard Worker s = '\u00a3'.encode()[:1] 13542*da0073e9SAndroid Build Coastguard Worker self.checkScript(index_str_to_tensor, (s,)) 13543*da0073e9SAndroid Build Coastguard Worker 13544*da0073e9SAndroid Build Coastguard Worker def test_chr(self): 13545*da0073e9SAndroid Build Coastguard Worker def fn(x): 13546*da0073e9SAndroid Build Coastguard Worker # type: (int) -> str 13547*da0073e9SAndroid Build Coastguard Worker return chr(x) 13548*da0073e9SAndroid Build Coastguard Worker 13549*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (1,)) 13550*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (97,)) 13551*da0073e9SAndroid Build Coastguard Worker 13552*da0073e9SAndroid Build Coastguard Worker def test_round(self): 13553*da0073e9SAndroid Build Coastguard Worker def round_float(x): 13554*da0073e9SAndroid Build Coastguard Worker # type: (float) -> float 13555*da0073e9SAndroid Build Coastguard Worker return round(x) 13556*da0073e9SAndroid Build Coastguard Worker 
13557*da0073e9SAndroid Build Coastguard Worker def round_int(x): 13558*da0073e9SAndroid Build Coastguard Worker # type: (int) -> float 13559*da0073e9SAndroid Build Coastguard Worker return round(x) 13560*da0073e9SAndroid Build Coastguard Worker 13561*da0073e9SAndroid Build Coastguard Worker self.checkScript(round_float, (1.5,)) 13562*da0073e9SAndroid Build Coastguard Worker self.checkScript(round_int, (2,)) 13563*da0073e9SAndroid Build Coastguard Worker 13564*da0073e9SAndroid Build Coastguard Worker def test_convert_base(self): 13565*da0073e9SAndroid Build Coastguard Worker def test_hex(x): 13566*da0073e9SAndroid Build Coastguard Worker # type: (int) -> str 13567*da0073e9SAndroid Build Coastguard Worker return hex(x) 13568*da0073e9SAndroid Build Coastguard Worker 13569*da0073e9SAndroid Build Coastguard Worker def test_oct(x): 13570*da0073e9SAndroid Build Coastguard Worker # type: (int) -> str 13571*da0073e9SAndroid Build Coastguard Worker return oct(x) 13572*da0073e9SAndroid Build Coastguard Worker 13573*da0073e9SAndroid Build Coastguard Worker def test_bin(x): 13574*da0073e9SAndroid Build Coastguard Worker # type: (int) -> str 13575*da0073e9SAndroid Build Coastguard Worker return bin(x) 13576*da0073e9SAndroid Build Coastguard Worker 13577*da0073e9SAndroid Build Coastguard Worker numbers = [-1000, -10, 0, 1, 10, 2343] 13578*da0073e9SAndroid Build Coastguard Worker for n in numbers: 13579*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_bin, (n,)) 13580*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_oct, (n,)) 13581*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_hex, (n,)) 13582*da0073e9SAndroid Build Coastguard Worker 13583*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") 13584*da0073e9SAndroid Build Coastguard Worker def test_get_set_state(self): 13585*da0073e9SAndroid Build Coastguard Worker class 
Root(torch.jit.ScriptModule): 13586*da0073e9SAndroid Build Coastguard Worker __constants__ = ['number'] 13587*da0073e9SAndroid Build Coastguard Worker 13588*da0073e9SAndroid Build Coastguard Worker def __init__(self, number): 13589*da0073e9SAndroid Build Coastguard Worker super().__init__() 13590*da0073e9SAndroid Build Coastguard Worker self.buffer1 = nn.Buffer(torch.ones(2, 2)) 13591*da0073e9SAndroid Build Coastguard Worker self.buffer2 = nn.Buffer(torch.ones(2, 2)) 13592*da0073e9SAndroid Build Coastguard Worker self.number = number 13593*da0073e9SAndroid Build Coastguard Worker 13594*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 13595*da0073e9SAndroid Build Coastguard Worker def __getstate__(self): 13596*da0073e9SAndroid Build Coastguard Worker return (self.buffer1, self.buffer2, 74, self.training) 13597*da0073e9SAndroid Build Coastguard Worker 13598*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 13599*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, state): 13600*da0073e9SAndroid Build Coastguard Worker self.buffer1 = state[0] + 10 13601*da0073e9SAndroid Build Coastguard Worker self.buffer2 = state[1] + 10 13602*da0073e9SAndroid Build Coastguard Worker self.training = state[3] 13603*da0073e9SAndroid Build Coastguard Worker 13604*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 13605*da0073e9SAndroid Build Coastguard Worker __constants__ = ['number'] 13606*da0073e9SAndroid Build Coastguard Worker 13607*da0073e9SAndroid Build Coastguard Worker def __init__(self, number, submodule): 13608*da0073e9SAndroid Build Coastguard Worker super().__init__() 13609*da0073e9SAndroid Build Coastguard Worker self.buffer1 = nn.Buffer(torch.ones(2, 2)) 13610*da0073e9SAndroid Build Coastguard Worker self.buffer2 = nn.Buffer(torch.ones(2, 2)) 13611*da0073e9SAndroid Build Coastguard Worker self.number = number 13612*da0073e9SAndroid Build Coastguard Worker self.submodule = submodule 13613*da0073e9SAndroid 
Build Coastguard Worker 13614*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 13615*da0073e9SAndroid Build Coastguard Worker def __getstate__(self): 13616*da0073e9SAndroid Build Coastguard Worker return (self.buffer1, self.buffer2, 74, self.submodule, self.training) 13617*da0073e9SAndroid Build Coastguard Worker 13618*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 13619*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, state): 13620*da0073e9SAndroid Build Coastguard Worker self.buffer1 = state[0] + 10 13621*da0073e9SAndroid Build Coastguard Worker self.buffer2 = state[1] + 10 13622*da0073e9SAndroid Build Coastguard Worker self.submodule = state[3] 13623*da0073e9SAndroid Build Coastguard Worker self.training = state[4] 13624*da0073e9SAndroid Build Coastguard Worker 13625*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 13626*da0073e9SAndroid Build Coastguard Worker m = M(23, submodule=Root(99)) 13627*da0073e9SAndroid Build Coastguard Worker m.save(fname) 13628*da0073e9SAndroid Build Coastguard Worker loaded = torch.jit.load(fname) 13629*da0073e9SAndroid Build Coastguard Worker 13630*da0073e9SAndroid Build Coastguard Worker # Check original module 13631*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.buffer1, torch.ones(2, 2)) 13632*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.buffer2, torch.ones(2, 2)) 13633*da0073e9SAndroid Build Coastguard Worker 13634*da0073e9SAndroid Build Coastguard Worker # Check top level module 13635*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 10) 13636*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10) 13637*da0073e9SAndroid Build Coastguard Worker 13638*da0073e9SAndroid Build Coastguard Worker # Check submodule 13639*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded.submodule.buffer1, torch.ones(2, 2) + 10) 
13640*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded.submodule.buffer2, torch.ones(2, 2) + 10) 13641*da0073e9SAndroid Build Coastguard Worker 13642*da0073e9SAndroid Build Coastguard Worker # Check simpler module 13643*da0073e9SAndroid Build Coastguard Worker class NoArgState(torch.nn.Module): 13644*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 13645*da0073e9SAndroid Build Coastguard Worker super().__init__() 13646*da0073e9SAndroid Build Coastguard Worker self.buffer1 = nn.Buffer(torch.ones(2, 2)) 13647*da0073e9SAndroid Build Coastguard Worker self.buffer2 = nn.Buffer(torch.ones(2, 2)) 13648*da0073e9SAndroid Build Coastguard Worker 13649*da0073e9SAndroid Build Coastguard Worker def forward(self): 13650*da0073e9SAndroid Build Coastguard Worker pass 13651*da0073e9SAndroid Build Coastguard Worker 13652*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 13653*da0073e9SAndroid Build Coastguard Worker def __getstate__(self): 13654*da0073e9SAndroid Build Coastguard Worker return 5, self.training 13655*da0073e9SAndroid Build Coastguard Worker 13656*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 13657*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, state): 13658*da0073e9SAndroid Build Coastguard Worker self.buffer1 = torch.ones(2, 2) + state[0] 13659*da0073e9SAndroid Build Coastguard Worker self.buffer2 = torch.ones(2, 2) + 10 13660*da0073e9SAndroid Build Coastguard Worker self.training = state[1] 13661*da0073e9SAndroid Build Coastguard Worker 13662*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 13663*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(NoArgState()) 13664*da0073e9SAndroid Build Coastguard Worker m.save(fname) 13665*da0073e9SAndroid Build Coastguard Worker loaded = torch.jit.load(fname) 13666*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded.buffer1, torch.ones(2, 2) + 5) 13667*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(loaded.buffer2, torch.ones(2, 2) + 10) 13668*da0073e9SAndroid Build Coastguard Worker 13669*da0073e9SAndroid Build Coastguard Worker 13670*da0073e9SAndroid Build Coastguard Worker 13671*da0073e9SAndroid Build Coastguard Worker def test_string_slicing(self): 13672*da0073e9SAndroid Build Coastguard Worker def fn1(x): 13673*da0073e9SAndroid Build Coastguard Worker # type: (str) -> str 13674*da0073e9SAndroid Build Coastguard Worker return x[1:3] 13675*da0073e9SAndroid Build Coastguard Worker 13676*da0073e9SAndroid Build Coastguard Worker def fn2(x): 13677*da0073e9SAndroid Build Coastguard Worker # type: (str) -> str 13678*da0073e9SAndroid Build Coastguard Worker return x[-1:3] 13679*da0073e9SAndroid Build Coastguard Worker 13680*da0073e9SAndroid Build Coastguard Worker def fn3(x): 13681*da0073e9SAndroid Build Coastguard Worker # type: (str) -> str 13682*da0073e9SAndroid Build Coastguard Worker return x[3:1] 13683*da0073e9SAndroid Build Coastguard Worker 13684*da0073e9SAndroid Build Coastguard Worker def fn4(x): 13685*da0073e9SAndroid Build Coastguard Worker # type: (str) -> str 13686*da0073e9SAndroid Build Coastguard Worker return x[3:100] 13687*da0073e9SAndroid Build Coastguard Worker 13688*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn1, ("abcdefghi",)) 13689*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn2, ("abcdefghi",)) 13690*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn3, ("abcdefghi",)) 13691*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn4, ("abcdefghi",)) 13692*da0073e9SAndroid Build Coastguard Worker 13693*da0073e9SAndroid Build Coastguard Worker def test_early_return_closure(self): 13694*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 13695*da0073e9SAndroid Build Coastguard Worker def tanh(self): 13696*da0073e9SAndroid Build Coastguard Worker output = torch.tanh(self) 13697*da0073e9SAndroid Build Coastguard Worker def backward(grad_output): 13698*da0073e9SAndroid Build 
Coastguard Worker pass 13699*da0073e9SAndroid Build Coastguard Worker return output, backward 13700*da0073e9SAndroid Build Coastguard Worker ''') 13701*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 13702*da0073e9SAndroid Build Coastguard Worker g = cu.tanh.graph 13703*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::Closure_0", 2).check("NoneType = prim::Constant") \ 13704*da0073e9SAndroid Build Coastguard Worker .check_next("return").run(g) 13705*da0073e9SAndroid Build Coastguard Worker 13706*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 13707*da0073e9SAndroid Build Coastguard Worker def tanh(self): 13708*da0073e9SAndroid Build Coastguard Worker output = torch.tanh(self) 13709*da0073e9SAndroid Build Coastguard Worker def backward(grad_output): 13710*da0073e9SAndroid Build Coastguard Worker a = 1 13711*da0073e9SAndroid Build Coastguard Worker if output: 13712*da0073e9SAndroid Build Coastguard Worker return 1 13713*da0073e9SAndroid Build Coastguard Worker else: 13714*da0073e9SAndroid Build Coastguard Worker a = 2 13715*da0073e9SAndroid Build Coastguard Worker return a 13716*da0073e9SAndroid Build Coastguard Worker return output, backward 13717*da0073e9SAndroid Build Coastguard Worker ''') 13718*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 13719*da0073e9SAndroid Build Coastguard Worker g = cu.tanh.graph 13720*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::Closure_0", 2).check("int = prim::If") \ 13721*da0073e9SAndroid Build Coastguard Worker .run(g) 13722*da0073e9SAndroid Build Coastguard Worker 13723*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 13724*da0073e9SAndroid Build Coastguard Worker def loop_in_closure(self): 13725*da0073e9SAndroid Build Coastguard Worker output = torch.tanh(self) 13726*da0073e9SAndroid Build Coastguard Worker def backward(grad_output): 13727*da0073e9SAndroid Build Coastguard Worker for i in range(3): 
13728*da0073e9SAndroid Build Coastguard Worker return 1 13729*da0073e9SAndroid Build Coastguard Worker return 4 13730*da0073e9SAndroid Build Coastguard Worker return output, backward 13731*da0073e9SAndroid Build Coastguard Worker ''') 13732*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 13733*da0073e9SAndroid Build Coastguard Worker fc = FileCheck() 13734*da0073e9SAndroid Build Coastguard Worker fc.check("prim::Closure").check("(Tensor, NoneType) = prim::TupleConstruct") 13735*da0073e9SAndroid Build Coastguard Worker # Loop then two if's added in exit transform 13736*da0073e9SAndroid Build Coastguard Worker fc.check("prim::Closure").check("prim::Loop").check_count("prim::If", 2) 13737*da0073e9SAndroid Build Coastguard Worker fc.run(cu.loop_in_closure.graph) 13738*da0073e9SAndroid Build Coastguard Worker 13739*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 13740*da0073e9SAndroid Build Coastguard Worker def tanh(self): 13741*da0073e9SAndroid Build Coastguard Worker output = torch.tanh(self) 13742*da0073e9SAndroid Build Coastguard Worker def backward(grad_output): 13743*da0073e9SAndroid Build Coastguard Worker if 1 == 1: 13744*da0073e9SAndroid Build Coastguard Worker return 1 13745*da0073e9SAndroid Build Coastguard Worker else: 13746*da0073e9SAndroid Build Coastguard Worker return 1. 
13747*da0073e9SAndroid Build Coastguard Worker return output, backward 13748*da0073e9SAndroid Build Coastguard Worker ''') 13749*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "returned a value of type int but"): 13750*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 13751*da0073e9SAndroid Build Coastguard Worker 13752*da0073e9SAndroid Build Coastguard Worker @_inline_everything 13753*da0073e9SAndroid Build Coastguard Worker def test_early_return_fork_join(self): 13754*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13755*da0073e9SAndroid Build Coastguard Worker def foo(x): 13756*da0073e9SAndroid Build Coastguard Worker if x.dim() == 2: 13757*da0073e9SAndroid Build Coastguard Worker return torch.neg(x), x 13758*da0073e9SAndroid Build Coastguard Worker else: 13759*da0073e9SAndroid Build Coastguard Worker return torch.neg(x), x + 1 13760*da0073e9SAndroid Build Coastguard Worker 13761*da0073e9SAndroid Build Coastguard Worker x = torch.rand(3, 4) 13762*da0073e9SAndroid Build Coastguard Worker 13763*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13764*da0073e9SAndroid Build Coastguard Worker def wait_script(x): 13765*da0073e9SAndroid Build Coastguard Worker fut = torch.jit._fork(foo, x) 13766*da0073e9SAndroid Build Coastguard Worker y_hat = foo(x) 13767*da0073e9SAndroid Build Coastguard Worker y = torch.jit._wait(fut) 13768*da0073e9SAndroid Build Coastguard Worker return y, y_hat 13769*da0073e9SAndroid Build Coastguard Worker 13770*da0073e9SAndroid Build Coastguard Worker FileCheck().check("with prim::fork").check("prim::If").check("return")\ 13771*da0073e9SAndroid Build Coastguard Worker .run(wait_script.graph) 13772*da0073e9SAndroid Build Coastguard Worker 13773*da0073e9SAndroid Build Coastguard Worker def test_early_return_type_refinement(self): 13774*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 13775*da0073e9SAndroid Build Coastguard Worker def test(x): 
13776*da0073e9SAndroid Build Coastguard Worker # type: (Optional[int]) -> int 13777*da0073e9SAndroid Build Coastguard Worker if x is None: 13778*da0073e9SAndroid Build Coastguard Worker return 1 13779*da0073e9SAndroid Build Coastguard Worker else: 13780*da0073e9SAndroid Build Coastguard Worker return x 13781*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test(None), 1) 13782*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test(2), 2) 13783*da0073e9SAndroid Build Coastguard Worker 13784*da0073e9SAndroid Build Coastguard Worker def test_exceptions_with_control_flow(self): 13785*da0073e9SAndroid Build Coastguard Worker def test_num_ifs(func, num_ifs): 13786*da0073e9SAndroid Build Coastguard Worker g = torch.jit.script(func).graph 13787*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::If", num_ifs, exactly=True).run(g) 13788*da0073e9SAndroid Build Coastguard Worker 13789*da0073e9SAndroid Build Coastguard Worker def no_guard_ifs_added(x): 13790*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 13791*da0073e9SAndroid Build Coastguard Worker if x == 1: 13792*da0073e9SAndroid Build Coastguard Worker return 1 13793*da0073e9SAndroid Build Coastguard Worker else: 13794*da0073e9SAndroid Build Coastguard Worker if x == 2: 13795*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("hi") 13796*da0073e9SAndroid Build Coastguard Worker else: 13797*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("hi") 13798*da0073e9SAndroid Build Coastguard Worker 13799*da0073e9SAndroid Build Coastguard Worker self.checkScript(no_guard_ifs_added, (1,)) 13800*da0073e9SAndroid Build Coastguard Worker self.checkScriptRaisesRegex(no_guard_ifs_added, (2,), Exception, "") 13801*da0073e9SAndroid Build Coastguard Worker test_num_ifs(no_guard_ifs_added, 2) 13802*da0073e9SAndroid Build Coastguard Worker 13803*da0073e9SAndroid Build Coastguard Worker # FUNCTION LOOKS LIKE: 13804*da0073e9SAndroid Build Coastguard Worker # graph(%x.1 : 
int): 13805*da0073e9SAndroid Build Coastguard Worker # %7 : str = prim::Constant[value="Exception"]() 13806*da0073e9SAndroid Build Coastguard Worker # %2 : int = prim::Constant[value=1]() 13807*da0073e9SAndroid Build Coastguard Worker # %5 : int = prim::Constant[value=2]() 13808*da0073e9SAndroid Build Coastguard Worker # %19 : int = prim::Uninitialized() 13809*da0073e9SAndroid Build Coastguard Worker # %3 : bool = aten::eq(%x.1, %2) 13810*da0073e9SAndroid Build Coastguard Worker # %20 : int = prim::If(%3) 13811*da0073e9SAndroid Build Coastguard Worker # block0(): 13812*da0073e9SAndroid Build Coastguard Worker # -> (%2) 13813*da0073e9SAndroid Build Coastguard Worker # block1(): 13814*da0073e9SAndroid Build Coastguard Worker # %6 : bool = aten::eq(%x.1, %5) 13815*da0073e9SAndroid Build Coastguard Worker # = prim::If(%6) 13816*da0073e9SAndroid Build Coastguard Worker # block0(): 13817*da0073e9SAndroid Build Coastguard Worker # = prim::RaiseException(%7) 13818*da0073e9SAndroid Build Coastguard Worker # -> () 13819*da0073e9SAndroid Build Coastguard Worker # block1(): 13820*da0073e9SAndroid Build Coastguard Worker # = prim::RaiseException(%7) 13821*da0073e9SAndroid Build Coastguard Worker # -> () 13822*da0073e9SAndroid Build Coastguard Worker # -> (%19) 13823*da0073e9SAndroid Build Coastguard Worker # return (%20) 13824*da0073e9SAndroid Build Coastguard Worker 13825*da0073e9SAndroid Build Coastguard Worker def no_ifs_added(x): 13826*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 13827*da0073e9SAndroid Build Coastguard Worker if x < 0: 13828*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("hi") 13829*da0073e9SAndroid Build Coastguard Worker return x 13830*da0073e9SAndroid Build Coastguard Worker 13831*da0073e9SAndroid Build Coastguard Worker self.checkScript(no_ifs_added, (1,)) 13832*da0073e9SAndroid Build Coastguard Worker self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "") 13833*da0073e9SAndroid Build Coastguard Worker 
test_num_ifs(no_ifs_added, 1) 13834*da0073e9SAndroid Build Coastguard Worker 13835*da0073e9SAndroid Build Coastguard Worker def test_if_might(x): 13836*da0073e9SAndroid Build Coastguard Worker # type: (int) 13837*da0073e9SAndroid Build Coastguard Worker if x > 0: 13838*da0073e9SAndroid Build Coastguard Worker if x == 1: 13839*da0073e9SAndroid Build Coastguard Worker return 1 13840*da0073e9SAndroid Build Coastguard Worker else: 13841*da0073e9SAndroid Build Coastguard Worker a = 2 13842*da0073e9SAndroid Build Coastguard Worker else: 13843*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("hi") 13844*da0073e9SAndroid Build Coastguard Worker return a + 2 13845*da0073e9SAndroid Build Coastguard Worker 13846*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_if_might, (1,)) 13847*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_if_might, (3,)) 13848*da0073e9SAndroid Build Coastguard Worker self.checkScriptRaisesRegex(no_ifs_added, (-2,), Exception, "") 13849*da0073e9SAndroid Build Coastguard Worker test_num_ifs(test_if_might, 3) # one if added to guard a + 2 13850*da0073e9SAndroid Build Coastguard Worker 13851*da0073e9SAndroid Build Coastguard Worker def test_loop_no_escape(x): 13852*da0073e9SAndroid Build Coastguard Worker # type: (int) 13853*da0073e9SAndroid Build Coastguard Worker if x >= 0: 13854*da0073e9SAndroid Build Coastguard Worker for i in range(x): 13855*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("hi") 13856*da0073e9SAndroid Build Coastguard Worker else: 13857*da0073e9SAndroid Build Coastguard Worker return 5 13858*da0073e9SAndroid Build Coastguard Worker return x + 3 13859*da0073e9SAndroid Build Coastguard Worker 13860*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_loop_no_escape, (0,)) 13861*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_loop_no_escape, (-1,)) 13862*da0073e9SAndroid Build Coastguard Worker self.checkScriptRaisesRegex(test_loop_no_escape, (1,), Exception, "") 
13863*da0073e9SAndroid Build Coastguard Worker 13864*da0073e9SAndroid Build Coastguard Worker # if guard gets optimized away 13865*da0073e9SAndroid Build Coastguard Worker test_num_ifs(test_loop_no_escape, 1) 13866*da0073e9SAndroid Build Coastguard Worker 13867*da0073e9SAndroid Build Coastguard Worker def test_loop_exception_with_continue(x): 13868*da0073e9SAndroid Build Coastguard Worker # type: (int) 13869*da0073e9SAndroid Build Coastguard Worker i = 0 13870*da0073e9SAndroid Build Coastguard Worker for i in range(5): 13871*da0073e9SAndroid Build Coastguard Worker if i == x: 13872*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("hi") 13873*da0073e9SAndroid Build Coastguard Worker else: 13874*da0073e9SAndroid Build Coastguard Worker continue 13875*da0073e9SAndroid Build Coastguard Worker print(i) 13876*da0073e9SAndroid Build Coastguard Worker return i + 5 13877*da0073e9SAndroid Build Coastguard Worker 13878*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_loop_exception_with_continue, (-1,)) 13879*da0073e9SAndroid Build Coastguard Worker self.checkScriptRaisesRegex(test_loop_exception_with_continue, (1,), Exception, "") 13880*da0073e9SAndroid Build Coastguard Worker test_num_ifs(test_loop_exception_with_continue, 1) # no ifs added to guard print 13881*da0073e9SAndroid Build Coastguard Worker 13882*da0073e9SAndroid Build Coastguard Worker 13883*da0073e9SAndroid Build Coastguard Worker def test_exception_exits_closure(self): 13884*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 13885*da0073e9SAndroid Build Coastguard Worker def no_return_func(self): 13886*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 13887*da0073e9SAndroid Build Coastguard Worker output = torch.tanh(self) 13888*da0073e9SAndroid Build Coastguard Worker def backward(grad_output): 13889*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Hi") 13890*da0073e9SAndroid Build Coastguard Worker ''') 13891*da0073e9SAndroid Build Coastguard 
Worker with self.assertRaisesRegex(RuntimeError, "does not return along all"): 13892*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 13893*da0073e9SAndroid Build Coastguard Worker 13894*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 13895*da0073e9SAndroid Build Coastguard Worker def test_exit_pair_reset(x): 13896*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 13897*da0073e9SAndroid Build Coastguard Worker if x > 0: 13898*da0073e9SAndroid Build Coastguard Worker a = 0 13899*da0073e9SAndroid Build Coastguard Worker def backward(grad_output): 13900*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Hi") 13901*da0073e9SAndroid Build Coastguard Worker a = a + 1 13902*da0073e9SAndroid Build Coastguard Worker else: 13903*da0073e9SAndroid Build Coastguard Worker return x 13904*da0073e9SAndroid Build Coastguard Worker return a + 1 13905*da0073e9SAndroid Build Coastguard Worker ''') 13906*da0073e9SAndroid Build Coastguard Worker func = torch.jit.CompilationUnit(code).test_exit_pair_reset 13907*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(1,), 2) 13908*da0073e9SAndroid Build Coastguard Worker self.assertEqual(func(-1,), -1) 13909*da0073e9SAndroid Build Coastguard Worker # final a + 1 gets inlined into the first branch and optimized away 13910*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::If", 1, exactly=True).run(func.graph) 13911*da0073e9SAndroid Build Coastguard Worker 13912*da0073e9SAndroid Build Coastguard Worker def test_non_final_return(self): 13913*da0073e9SAndroid Build Coastguard Worker def simple(x): 13914*da0073e9SAndroid Build Coastguard Worker if bool(x > 3): 13915*da0073e9SAndroid Build Coastguard Worker return x + 1 13916*da0073e9SAndroid Build Coastguard Worker else: 13917*da0073e9SAndroid Build Coastguard Worker return x + 2 13918*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("nope") 13919*da0073e9SAndroid Build Coastguard Worker 
13920*da0073e9SAndroid Build Coastguard Worker def nest(x): 13921*da0073e9SAndroid Build Coastguard Worker x = x + 1 13922*da0073e9SAndroid Build Coastguard Worker if bool(x > 3): 13923*da0073e9SAndroid Build Coastguard Worker if bool(x > 4): 13924*da0073e9SAndroid Build Coastguard Worker x += 1 13925*da0073e9SAndroid Build Coastguard Worker return x + 1 13926*da0073e9SAndroid Build Coastguard Worker else: 13927*da0073e9SAndroid Build Coastguard Worker return x + 2 13928*da0073e9SAndroid Build Coastguard Worker 13929*da0073e9SAndroid Build Coastguard Worker def early_ret(x): 13930*da0073e9SAndroid Build Coastguard Worker x = x + 1 13931*da0073e9SAndroid Build Coastguard Worker if bool(x > 3): 13932*da0073e9SAndroid Build Coastguard Worker return x + 1 13933*da0073e9SAndroid Build Coastguard Worker x = x + 1 13934*da0073e9SAndroid Build Coastguard Worker return x + 2 13935*da0073e9SAndroid Build Coastguard Worker 13936*da0073e9SAndroid Build Coastguard Worker def nest_early_ret(x): 13937*da0073e9SAndroid Build Coastguard Worker x = x + 1 13938*da0073e9SAndroid Build Coastguard Worker if bool(x > 3): 13939*da0073e9SAndroid Build Coastguard Worker if bool(x > 4): 13940*da0073e9SAndroid Build Coastguard Worker return x + 2 13941*da0073e9SAndroid Build Coastguard Worker return x + 1 13942*da0073e9SAndroid Build Coastguard Worker x = x + 1 13943*da0073e9SAndroid Build Coastguard Worker return x + 2 13944*da0073e9SAndroid Build Coastguard Worker 13945*da0073e9SAndroid Build Coastguard Worker def not_early_ret(x): 13946*da0073e9SAndroid Build Coastguard Worker s = "" 13947*da0073e9SAndroid Build Coastguard Worker if bool(x > 3): 13948*da0073e9SAndroid Build Coastguard Worker if bool(x > 4): 13949*da0073e9SAndroid Build Coastguard Worker return 1, s 13950*da0073e9SAndroid Build Coastguard Worker s += "foo" 13951*da0073e9SAndroid Build Coastguard Worker else: 13952*da0073e9SAndroid Build Coastguard Worker s += "5" 13953*da0073e9SAndroid Build Coastguard Worker s += "hi" 
13954*da0073e9SAndroid Build Coastguard Worker return 7, s 13955*da0073e9SAndroid Build Coastguard Worker 13956*da0073e9SAndroid Build Coastguard Worker def not_total_ret(x): 13957*da0073e9SAndroid Build Coastguard Worker s = "" 13958*da0073e9SAndroid Build Coastguard Worker if bool(x > 3): 13959*da0073e9SAndroid Build Coastguard Worker if bool(x > 4): 13960*da0073e9SAndroid Build Coastguard Worker return 1, s 13961*da0073e9SAndroid Build Coastguard Worker else: 13962*da0073e9SAndroid Build Coastguard Worker return 2, s 13963*da0073e9SAndroid Build Coastguard Worker else: 13964*da0073e9SAndroid Build Coastguard Worker s += "5" 13965*da0073e9SAndroid Build Coastguard Worker return 7, s 13966*da0073e9SAndroid Build Coastguard Worker 13967*da0073e9SAndroid Build Coastguard Worker for i in range(3): 13968*da0073e9SAndroid Build Coastguard Worker for func in [simple, nest, early_ret, nest_early_ret, not_early_ret, 13969*da0073e9SAndroid Build Coastguard Worker not_total_ret]: 13970*da0073e9SAndroid Build Coastguard Worker self.checkScript(func, (torch.tensor(2.5 + i),)) 13971*da0073e9SAndroid Build Coastguard Worker 13972*da0073e9SAndroid Build Coastguard Worker def vars_used_after_ret(x): 13973*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 13974*da0073e9SAndroid Build Coastguard Worker if x == 0: 13975*da0073e9SAndroid Build Coastguard Worker return x 13976*da0073e9SAndroid Build Coastguard Worker else: 13977*da0073e9SAndroid Build Coastguard Worker y = 2 13978*da0073e9SAndroid Build Coastguard Worker z = 3 13979*da0073e9SAndroid Build Coastguard Worker return x + y * z 13980*da0073e9SAndroid Build Coastguard Worker 13981*da0073e9SAndroid Build Coastguard Worker self.checkScript(vars_used_after_ret, (1,)) 13982*da0073e9SAndroid Build Coastguard Worker self.checkScript(vars_used_after_ret, (0,)) 13983*da0073e9SAndroid Build Coastguard Worker 13984*da0073e9SAndroid Build Coastguard Worker def complicated(x): 13985*da0073e9SAndroid Build Coastguard Worker 
# type: (int) -> int 13986*da0073e9SAndroid Build Coastguard Worker if x: 13987*da0073e9SAndroid Build Coastguard Worker if x == 2: 13988*da0073e9SAndroid Build Coastguard Worker return 1 13989*da0073e9SAndroid Build Coastguard Worker assert 1 == 2 13990*da0073e9SAndroid Build Coastguard Worker else: 13991*da0073e9SAndroid Build Coastguard Worker if x == 3: 13992*da0073e9SAndroid Build Coastguard Worker return 2 13993*da0073e9SAndroid Build Coastguard Worker assert 1 == 2 13994*da0073e9SAndroid Build Coastguard Worker else: 13995*da0073e9SAndroid Build Coastguard Worker a = 2 13996*da0073e9SAndroid Build Coastguard Worker b = 3 13997*da0073e9SAndroid Build Coastguard Worker else: 13998*da0073e9SAndroid Build Coastguard Worker a = 4 13999*da0073e9SAndroid Build Coastguard Worker b = 1 14000*da0073e9SAndroid Build Coastguard Worker return a + b 14001*da0073e9SAndroid Build Coastguard Worker assert 1 == 2 14002*da0073e9SAndroid Build Coastguard Worker 14003*da0073e9SAndroid Build Coastguard Worker for i in range(4): 14004*da0073e9SAndroid Build Coastguard Worker self.checkScript(complicated, (i,)) 14005*da0073e9SAndroid Build Coastguard Worker 14006*da0073e9SAndroid Build Coastguard Worker def test_partial_returns(self): 14007*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "does not return along all"): 14008*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14009*da0073e9SAndroid Build Coastguard Worker def no_ret(): 14010*da0073e9SAndroid Build Coastguard Worker # type: () -> int 14011*da0073e9SAndroid Build Coastguard Worker pass 14012*da0073e9SAndroid Build Coastguard Worker 14013*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "does not return along all"): 14014*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14015*da0073e9SAndroid Build Coastguard Worker def partial(x): 14016*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> int 14017*da0073e9SAndroid Build 
Coastguard Worker if x: 14018*da0073e9SAndroid Build Coastguard Worker return 1 14019*da0073e9SAndroid Build Coastguard Worker 14020*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "does not return along all"): 14021*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14022*da0073e9SAndroid Build Coastguard Worker def typed_none(): 14023*da0073e9SAndroid Build Coastguard Worker # type: () -> Optional[int] 14024*da0073e9SAndroid Build Coastguard Worker pass 14025*da0073e9SAndroid Build Coastguard Worker 14026*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14027*da0073e9SAndroid Build Coastguard Worker def none_ret(): 14028*da0073e9SAndroid Build Coastguard Worker pass 14029*da0073e9SAndroid Build Coastguard Worker 14030*da0073e9SAndroid Build Coastguard Worker self.assertIs(none_ret(), None) 14031*da0073e9SAndroid Build Coastguard Worker FileCheck().check(": None").run(none_ret.graph) 14032*da0073e9SAndroid Build Coastguard Worker 14033*da0073e9SAndroid Build Coastguard Worker def test_early_returns_loops(self): 14034*da0073e9SAndroid Build Coastguard Worker def nest_while_ret(x): 14035*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 14036*da0073e9SAndroid Build Coastguard Worker y = 4 14037*da0073e9SAndroid Build Coastguard Worker while x < 4: 14038*da0073e9SAndroid Build Coastguard Worker if x < 3: 14039*da0073e9SAndroid Build Coastguard Worker return y 14040*da0073e9SAndroid Build Coastguard Worker else: 14041*da0073e9SAndroid Build Coastguard Worker y = y + 1 14042*da0073e9SAndroid Build Coastguard Worker break 14043*da0073e9SAndroid Build Coastguard Worker y = y + 2 14044*da0073e9SAndroid Build Coastguard Worker y = y + 1 14045*da0073e9SAndroid Build Coastguard Worker return y 14046*da0073e9SAndroid Build Coastguard Worker 14047*da0073e9SAndroid Build Coastguard Worker self.checkScript(nest_while_ret, (2,)) 14048*da0073e9SAndroid Build Coastguard Worker self.checkScript(nest_while_ret, (3,)) 
14049*da0073e9SAndroid Build Coastguard Worker self.checkScript(nest_while_ret, (4,)) 14050*da0073e9SAndroid Build Coastguard Worker 14051*da0073e9SAndroid Build Coastguard Worker def loop_ret(x, y): 14052*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> (int) 14053*da0073e9SAndroid Build Coastguard Worker i = 0 14054*da0073e9SAndroid Build Coastguard Worker for i in range(x): 14055*da0073e9SAndroid Build Coastguard Worker if x == y: 14056*da0073e9SAndroid Build Coastguard Worker return x + y 14057*da0073e9SAndroid Build Coastguard Worker i = i + y 14058*da0073e9SAndroid Build Coastguard Worker i = i - 1 14059*da0073e9SAndroid Build Coastguard Worker return i 14060*da0073e9SAndroid Build Coastguard Worker 14061*da0073e9SAndroid Build Coastguard Worker self.checkScript(loop_ret, (3, 3)) 14062*da0073e9SAndroid Build Coastguard Worker self.checkScript(loop_ret, (2, 3)) 14063*da0073e9SAndroid Build Coastguard Worker self.checkScript(loop_ret, (3, 1)) 14064*da0073e9SAndroid Build Coastguard Worker 14065*da0073e9SAndroid Build Coastguard Worker def test_will_ret(y): 14066*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 14067*da0073e9SAndroid Build Coastguard Worker for i in range(y): 14068*da0073e9SAndroid Build Coastguard Worker return 2 14069*da0073e9SAndroid Build Coastguard Worker return 1 14070*da0073e9SAndroid Build Coastguard Worker 14071*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_will_ret, (0,)) 14072*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_will_ret, (1,)) 14073*da0073e9SAndroid Build Coastguard Worker 14074*da0073e9SAndroid Build Coastguard Worker def test_loop_nest_ret(y): 14075*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 14076*da0073e9SAndroid Build Coastguard Worker for i in range(y): 14077*da0073e9SAndroid Build Coastguard Worker for i in range(y - 2): 14078*da0073e9SAndroid Build Coastguard Worker return 10 14079*da0073e9SAndroid Build Coastguard Worker return 5 
14080*da0073e9SAndroid Build Coastguard Worker return 0 14081*da0073e9SAndroid Build Coastguard Worker 14082*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_loop_nest_ret, (0,)) 14083*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_loop_nest_ret, (1,)) 14084*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_loop_nest_ret, (2,)) 14085*da0073e9SAndroid Build Coastguard Worker 14086*da0073e9SAndroid Build Coastguard Worker def test_nn_init(self): 14087*da0073e9SAndroid Build Coastguard Worker tests = ( 14088*da0073e9SAndroid Build Coastguard Worker ('constant_', (lambda: (torch.ones(2, 2), 2.5)), "Tensor, float"), 14089*da0073e9SAndroid Build Coastguard Worker ('ones_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14090*da0073e9SAndroid Build Coastguard Worker ('zeros_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14091*da0073e9SAndroid Build Coastguard Worker ('uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14092*da0073e9SAndroid Build Coastguard Worker ('normal_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14093*da0073e9SAndroid Build Coastguard Worker ('xavier_normal_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14094*da0073e9SAndroid Build Coastguard Worker ('xavier_uniform_', (lambda: (torch.ones(2, 2),)), "Tensor"), 14095*da0073e9SAndroid Build Coastguard Worker ) 14096*da0073e9SAndroid Build Coastguard Worker 14097*da0073e9SAndroid Build Coastguard Worker for name, args_fn, type_str in tests: 14098*da0073e9SAndroid Build Coastguard Worker # Build test code 14099*da0073e9SAndroid Build Coastguard Worker arg_str = ', '.join([chr(i + ord('a')) for i in range(len(args_fn()))]) 14100*da0073e9SAndroid Build Coastguard Worker 14101*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 14102*da0073e9SAndroid Build Coastguard Worker def test({arg_str}): 14103*da0073e9SAndroid Build Coastguard Worker # type: ({type_str}) 14104*da0073e9SAndroid Build Coastguard Worker return torch.nn.init.{name}({arg_str}) 
14105*da0073e9SAndroid Build Coastguard Worker ''').format(arg_str=arg_str, type_str=type_str, name=name) 14106*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 14107*da0073e9SAndroid Build Coastguard Worker 14108*da0073e9SAndroid Build Coastguard Worker # Compare functions 14109*da0073e9SAndroid Build Coastguard Worker init_fn = getattr(torch.nn.init, name) 14110*da0073e9SAndroid Build Coastguard Worker script_out = self.runAndSaveRNG(cu.test, args_fn()) 14111*da0073e9SAndroid Build Coastguard Worker eager_out = self.runAndSaveRNG(init_fn, args_fn()) 14112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(script_out, eager_out) 14113*da0073e9SAndroid Build Coastguard Worker 14114*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::PythonOp").run(cu.test.graph) 14115*da0073e9SAndroid Build Coastguard Worker 14116*da0073e9SAndroid Build Coastguard Worker def test_nn_init_generator(self): 14117*da0073e9SAndroid Build Coastguard Worker init_fns = ( 14118*da0073e9SAndroid Build Coastguard Worker 'uniform_', 'normal_', 'xavier_normal_', 'xavier_uniform_', 14119*da0073e9SAndroid Build Coastguard Worker ) 14120*da0073e9SAndroid Build Coastguard Worker 14121*da0073e9SAndroid Build Coastguard Worker for name in init_fns: 14122*da0073e9SAndroid Build Coastguard Worker # Build test code 14123*da0073e9SAndroid Build Coastguard Worker code = dedent(''' 14124*da0073e9SAndroid Build Coastguard Worker def test(tensor, generator): 14125*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Generator) 14126*da0073e9SAndroid Build Coastguard Worker return torch.nn.init.{name}(tensor, generator=generator) 14127*da0073e9SAndroid Build Coastguard Worker ''').format(name=name) 14128*da0073e9SAndroid Build Coastguard Worker cu = torch.jit.CompilationUnit(code) 14129*da0073e9SAndroid Build Coastguard Worker 14130*da0073e9SAndroid Build Coastguard Worker # Compare functions 14131*da0073e9SAndroid Build Coastguard Worker init_fn = 
getattr(torch.nn.init, name) 14132*da0073e9SAndroid Build Coastguard Worker 14133*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(1) 14134*da0073e9SAndroid Build Coastguard Worker 14135*da0073e9SAndroid Build Coastguard Worker g = torch.Generator() 14136*da0073e9SAndroid Build Coastguard Worker g.manual_seed(2023) 14137*da0073e9SAndroid Build Coastguard Worker script_out = cu.test(torch.ones(2, 2), g) 14138*da0073e9SAndroid Build Coastguard Worker 14139*da0073e9SAndroid Build Coastguard Worker # Change the seed of the default generator to make 14140*da0073e9SAndroid Build Coastguard Worker # sure that we're using the provided generator 14141*da0073e9SAndroid Build Coastguard Worker torch.manual_seed(2) 14142*da0073e9SAndroid Build Coastguard Worker 14143*da0073e9SAndroid Build Coastguard Worker g = torch.Generator() 14144*da0073e9SAndroid Build Coastguard Worker g.manual_seed(2023) 14145*da0073e9SAndroid Build Coastguard Worker eager_out = init_fn(torch.ones(2, 2), generator=g) 14146*da0073e9SAndroid Build Coastguard Worker 14147*da0073e9SAndroid Build Coastguard Worker self.assertEqual(script_out, eager_out) 14148*da0073e9SAndroid Build Coastguard Worker 14149*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::PythonOp").run(cu.test.graph) 14150*da0073e9SAndroid Build Coastguard Worker 14151*da0073e9SAndroid Build Coastguard Worker def test_early_return_rewrite(self): 14152*da0073e9SAndroid Build Coastguard Worker def test_foo(x: bool): 14153*da0073e9SAndroid Build Coastguard Worker if x: 14154*da0073e9SAndroid Build Coastguard Worker return 1 14155*da0073e9SAndroid Build Coastguard Worker return 2 14156*da0073e9SAndroid Build Coastguard Worker 14157*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_foo, (True,)) 14158*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_foo, (False,)) 14159*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::If", 1, 
exactly=True).run(torch.jit.script(test_foo).graph) 14160*da0073e9SAndroid Build Coastguard Worker 14161*da0073e9SAndroid Build Coastguard Worker def test_multiple(x: int): 14162*da0073e9SAndroid Build Coastguard Worker if x == 5: 14163*da0073e9SAndroid Build Coastguard Worker return x * x 14164*da0073e9SAndroid Build Coastguard Worker else: 14165*da0073e9SAndroid Build Coastguard Worker y = 2 * x 14166*da0073e9SAndroid Build Coastguard Worker 14167*da0073e9SAndroid Build Coastguard Worker z = y * 2 14168*da0073e9SAndroid Build Coastguard Worker if z == 8: 14169*da0073e9SAndroid Build Coastguard Worker return 1 14170*da0073e9SAndroid Build Coastguard Worker 14171*da0073e9SAndroid Build Coastguard Worker if z != 16: 14172*da0073e9SAndroid Build Coastguard Worker z = z - 2 14173*da0073e9SAndroid Build Coastguard Worker abc = 4 14174*da0073e9SAndroid Build Coastguard Worker else: 14175*da0073e9SAndroid Build Coastguard Worker return 3 14176*da0073e9SAndroid Build Coastguard Worker 14177*da0073e9SAndroid Build Coastguard Worker z = z * abc 14178*da0073e9SAndroid Build Coastguard Worker return z * z * z 14179*da0073e9SAndroid Build Coastguard Worker 14180*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_multiple, (5,)) 14181*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_multiple, (2,)) 14182*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_multiple, (4,)) 14183*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_multiple, (3,)) 14184*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_multiple, (10,)) 14185*da0073e9SAndroid Build Coastguard Worker 14186*da0073e9SAndroid Build Coastguard Worker graph = torch.jit.script(test_multiple).graph 14187*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count("prim::If", 3, exactly=True).run(graph) 14188*da0073e9SAndroid Build Coastguard Worker 14189*da0073e9SAndroid Build Coastguard Worker def test_is_scripting_metacompile(self): 14190*da0073e9SAndroid 
Build Coastguard Worker @torch.jit.script 14191*da0073e9SAndroid Build Coastguard Worker def foo(): 14192*da0073e9SAndroid Build Coastguard Worker if torch.jit.is_scripting(): 14193*da0073e9SAndroid Build Coastguard Worker return 1 14194*da0073e9SAndroid Build Coastguard Worker else: 14195*da0073e9SAndroid Build Coastguard Worker print("hello") + 2 # will not be compiled 14196*da0073e9SAndroid Build Coastguard Worker 14197*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), 1) 14198*da0073e9SAndroid Build Coastguard Worker 14199*da0073e9SAndroid Build Coastguard Worker def test_boolean_literal_constant_metacompile(self): 14200*da0073e9SAndroid Build Coastguard Worker class Mod(torch.nn.Module): 14201*da0073e9SAndroid Build Coastguard Worker __constants__ = ['val'] 14202*da0073e9SAndroid Build Coastguard Worker 14203*da0073e9SAndroid Build Coastguard Worker def __init__(self, val): 14204*da0073e9SAndroid Build Coastguard Worker super().__init__() 14205*da0073e9SAndroid Build Coastguard Worker self.val = val 14206*da0073e9SAndroid Build Coastguard Worker 14207*da0073e9SAndroid Build Coastguard Worker def forward(self): 14208*da0073e9SAndroid Build Coastguard Worker if self.val: 14209*da0073e9SAndroid Build Coastguard Worker return 1 14210*da0073e9SAndroid Build Coastguard Worker else: 14211*da0073e9SAndroid Build Coastguard Worker return "2" 14212*da0073e9SAndroid Build Coastguard Worker 14213*da0073e9SAndroid Build Coastguard Worker self.checkModule(Mod(True), ()) 14214*da0073e9SAndroid Build Coastguard Worker self.checkModule(Mod(False), ()) 14215*da0073e9SAndroid Build Coastguard Worker 14216*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14217*da0073e9SAndroid Build Coastguard Worker def foo(): 14218*da0073e9SAndroid Build Coastguard Worker if True: 14219*da0073e9SAndroid Build Coastguard Worker return 1 14220*da0073e9SAndroid Build Coastguard Worker else: 14221*da0073e9SAndroid Build Coastguard Worker return "2" 14222*da0073e9SAndroid 
Build Coastguard Worker 14223*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), 1) 14224*da0073e9SAndroid Build Coastguard Worker 14225*da0073e9SAndroid Build Coastguard Worker def test_assert_is_scripting_metacompile(self): 14226*da0073e9SAndroid Build Coastguard Worker def foo(): 14227*da0073e9SAndroid Build Coastguard Worker assert not torch.jit.is_scripting(), "TestErrorMsg" 14228*da0073e9SAndroid Build Coastguard Worker print("hello") + 2 # will not be compiled 14229*da0073e9SAndroid Build Coastguard Worker 14230*da0073e9SAndroid Build Coastguard Worker f = torch.jit.script(foo) 14231*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.Error, "TestErrorMsg"): 14232*da0073e9SAndroid Build Coastguard Worker f() 14233*da0073e9SAndroid Build Coastguard Worker 14234*da0073e9SAndroid Build Coastguard Worker def test_isinstance_metacompile(self): 14235*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14236*da0073e9SAndroid Build Coastguard Worker def test_primitive_type(x): 14237*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 14238*da0073e9SAndroid Build Coastguard Worker if isinstance(x, int): 14239*da0073e9SAndroid Build Coastguard Worker return x + 1 14240*da0073e9SAndroid Build Coastguard Worker else: 14241*da0073e9SAndroid Build Coastguard Worker return x - 1 14242*da0073e9SAndroid Build Coastguard Worker 14243*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test_primitive_type(1), 2) 14244*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "Expected a value of type"): 14245*da0073e9SAndroid Build Coastguard Worker test_primitive_type(1.5) 14246*da0073e9SAndroid Build Coastguard Worker 14247*da0073e9SAndroid Build Coastguard Worker _MyNamedTuple = namedtuple('_MyNamedTuple', ['value']) 14248*da0073e9SAndroid Build Coastguard Worker 14249*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14250*da0073e9SAndroid Build Coastguard Worker def 
test_non_primitive_types(x): 14251*da0073e9SAndroid Build Coastguard Worker # type: (_MyNamedTuple) -> Tensor 14252*da0073e9SAndroid Build Coastguard Worker if isinstance(1, _MyNamedTuple): 14253*da0073e9SAndroid Build Coastguard Worker return 10 14254*da0073e9SAndroid Build Coastguard Worker 14255*da0073e9SAndroid Build Coastguard Worker if isinstance(x, _MyNamedTuple): 14256*da0073e9SAndroid Build Coastguard Worker return x.value + 1 14257*da0073e9SAndroid Build Coastguard Worker else: 14258*da0073e9SAndroid Build Coastguard Worker return 1 14259*da0073e9SAndroid Build Coastguard Worker 14260*da0073e9SAndroid Build Coastguard Worker out = test_non_primitive_types(_MyNamedTuple(value=torch.tensor(5.0))) 14261*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, torch.tensor(6.0)) 14262*da0073e9SAndroid Build Coastguard Worker 14263*da0073e9SAndroid Build Coastguard Worker def test_namedtuple_type_inference(self): 14264*da0073e9SAndroid Build Coastguard Worker _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('value', int)]) # noqa: UP014 14265*da0073e9SAndroid Build Coastguard Worker _UnannotatedNamedTuple = namedtuple('_NamedTupleUnAnnotated', ['value']) 14266*da0073e9SAndroid Build Coastguard Worker 14267*da0073e9SAndroid Build Coastguard Worker def test_check_named_tuple_value(): 14268*da0073e9SAndroid Build Coastguard Worker named_tuple = _AnnotatedNamedTuple(1) 14269*da0073e9SAndroid Build Coastguard Worker return named_tuple.value 14270*da0073e9SAndroid Build Coastguard Worker 14271*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_check_named_tuple_value, ()) 14272*da0073e9SAndroid Build Coastguard Worker 14273*da0073e9SAndroid Build Coastguard Worker def test_error(): 14274*da0073e9SAndroid Build Coastguard Worker return _UnannotatedNamedTuple(1) 14275*da0073e9SAndroid Build Coastguard Worker 14276*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, r"Expected a value of type \'Tensor 
\(inferred\)\' " 14277*da0073e9SAndroid Build Coastguard Worker r"for argument \'value\' but instead found type \'int\'."): 14278*da0073e9SAndroid Build Coastguard Worker torch.jit.script(test_error) 14279*da0073e9SAndroid Build Coastguard Worker 14280*da0073e9SAndroid Build Coastguard Worker def test_namedtuple_default_values_simple_type(self): 14281*da0073e9SAndroid Build Coastguard Worker 14282*da0073e9SAndroid Build Coastguard Worker class Point(NamedTuple): 14283*da0073e9SAndroid Build Coastguard Worker x: Optional[int] = None 14284*da0073e9SAndroid Build Coastguard Worker y: int = 2 14285*da0073e9SAndroid Build Coastguard Worker 14286*da0073e9SAndroid Build Coastguard Worker make_global(Point) 14287*da0073e9SAndroid Build Coastguard Worker 14288*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 14289*da0073e9SAndroid Build Coastguard Worker def forward(self, point: Point): 14290*da0073e9SAndroid Build Coastguard Worker return point 14291*da0073e9SAndroid Build Coastguard Worker 14292*da0073e9SAndroid Build Coastguard Worker p = Point(x=3, y=2) 14293*da0073e9SAndroid Build Coastguard Worker 14294*da0073e9SAndroid Build Coastguard Worker self.checkModule(M(), (p,)) 14295*da0073e9SAndroid Build Coastguard Worker self.checkModule(M(), (Point(),)) 14296*da0073e9SAndroid Build Coastguard Worker 14297*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(M()) 14298*da0073e9SAndroid Build Coastguard Worker 14299*da0073e9SAndroid Build Coastguard Worker FileCheck().check(r"NamedTuple(x : int? 
= None, y : int = 2))") \ 14300*da0073e9SAndroid Build Coastguard Worker .run(m.graph) 14301*da0073e9SAndroid Build Coastguard Worker 14302*da0073e9SAndroid Build Coastguard Worker def test_namedtuple_default_values_missing(self): 14303*da0073e9SAndroid Build Coastguard Worker 14304*da0073e9SAndroid Build Coastguard Worker class Point(NamedTuple): 14305*da0073e9SAndroid Build Coastguard Worker x: Optional[int] 14306*da0073e9SAndroid Build Coastguard Worker y: int 14307*da0073e9SAndroid Build Coastguard Worker z: int = 3 14308*da0073e9SAndroid Build Coastguard Worker 14309*da0073e9SAndroid Build Coastguard Worker make_global(Point) 14310*da0073e9SAndroid Build Coastguard Worker 14311*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 14312*da0073e9SAndroid Build Coastguard Worker def forward(self, point: Point): 14313*da0073e9SAndroid Build Coastguard Worker return point 14314*da0073e9SAndroid Build Coastguard Worker 14315*da0073e9SAndroid Build Coastguard Worker p1 = Point(x=3, y=2) 14316*da0073e9SAndroid Build Coastguard Worker p2 = Point(x=3, y=2, z=1) 14317*da0073e9SAndroid Build Coastguard Worker 14318*da0073e9SAndroid Build Coastguard Worker self.checkModule(M(), (p1,)) 14319*da0073e9SAndroid Build Coastguard Worker self.checkModule(M(), (p2,)) 14320*da0073e9SAndroid Build Coastguard Worker 14321*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(M()) 14322*da0073e9SAndroid Build Coastguard Worker 14323*da0073e9SAndroid Build Coastguard Worker FileCheck().check(r"NamedTuple(x : int?, y : int, z : int = 3))") \ 14324*da0073e9SAndroid Build Coastguard Worker .run(m.graph) 14325*da0073e9SAndroid Build Coastguard Worker 14326*da0073e9SAndroid Build Coastguard Worker def test_namedtuple_default_values_container_type(self): 14327*da0073e9SAndroid Build Coastguard Worker 14328*da0073e9SAndroid Build Coastguard Worker class Point(NamedTuple): 14329*da0073e9SAndroid Build Coastguard Worker x: Optional[List[int]] = None 14330*da0073e9SAndroid 
Build Coastguard Worker y: List[int] = [1, 2, 3] 14331*da0073e9SAndroid Build Coastguard Worker z: Optional[Dict[str, int]] = {"a": 1} 14332*da0073e9SAndroid Build Coastguard Worker 14333*da0073e9SAndroid Build Coastguard Worker make_global(Point) 14334*da0073e9SAndroid Build Coastguard Worker 14335*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 14336*da0073e9SAndroid Build Coastguard Worker def forward(self, point: Point): 14337*da0073e9SAndroid Build Coastguard Worker return point 14338*da0073e9SAndroid Build Coastguard Worker 14339*da0073e9SAndroid Build Coastguard Worker p = Point(x=[4, 5, 6], y=[3, 2, 1], z={"b": 2}) 14340*da0073e9SAndroid Build Coastguard Worker 14341*da0073e9SAndroid Build Coastguard Worker self.checkModule(M(), (p,)) 14342*da0073e9SAndroid Build Coastguard Worker self.checkModule(M(), (Point(),)) 14343*da0073e9SAndroid Build Coastguard Worker 14344*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(M()) 14345*da0073e9SAndroid Build Coastguard Worker 14346*da0073e9SAndroid Build Coastguard Worker first_line = r"NamedTuple(x : int[]? = None, y : int[] = " \ 14347*da0073e9SAndroid Build Coastguard Worker r"[1, 2, 3], z : Dict(str, int)? 
= {a: 1}))" 14348*da0073e9SAndroid Build Coastguard Worker 14349*da0073e9SAndroid Build Coastguard Worker FileCheck().check(first_line) \ 14350*da0073e9SAndroid Build Coastguard Worker .run(m.graph) 14351*da0073e9SAndroid Build Coastguard Worker 14352*da0073e9SAndroid Build Coastguard Worker def test_namedtuple_default_values_Tensor_type(self): 14353*da0073e9SAndroid Build Coastguard Worker 14354*da0073e9SAndroid Build Coastguard Worker class Point(NamedTuple): 14355*da0073e9SAndroid Build Coastguard Worker x: torch.Tensor = torch.rand(2, 3) 14356*da0073e9SAndroid Build Coastguard Worker 14357*da0073e9SAndroid Build Coastguard Worker make_global(Point) 14358*da0073e9SAndroid Build Coastguard Worker 14359*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 14360*da0073e9SAndroid Build Coastguard Worker def forward(self, point: Point): 14361*da0073e9SAndroid Build Coastguard Worker return point 14362*da0073e9SAndroid Build Coastguard Worker 14363*da0073e9SAndroid Build Coastguard Worker p = Point(x=torch.rand(2, 3)) 14364*da0073e9SAndroid Build Coastguard Worker 14365*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Tensors are not " 14366*da0073e9SAndroid Build Coastguard Worker "supported as default NamedTuple " 14367*da0073e9SAndroid Build Coastguard Worker "fields"): 14368*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(M()) 14369*da0073e9SAndroid Build Coastguard Worker m(p) 14370*da0073e9SAndroid Build Coastguard Worker 14371*da0073e9SAndroid Build Coastguard Worker def test_namedtuple_default_values_using_factory_constructor(self): 14372*da0073e9SAndroid Build Coastguard Worker Pair = namedtuple("Pair", ["x", "y"], defaults=(1, 2)) 14373*da0073e9SAndroid Build Coastguard Worker 14374*da0073e9SAndroid Build Coastguard Worker make_global(Pair) 14375*da0073e9SAndroid Build Coastguard Worker 14376*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14377*da0073e9SAndroid Build Coastguard Worker 
def fn(x: Pair) -> Pair: 14378*da0073e9SAndroid Build Coastguard Worker return x 14379*da0073e9SAndroid Build Coastguard Worker 14380*da0073e9SAndroid Build Coastguard Worker # TODO: We can't use `checkScript` with the NamedTuple factory 14381*da0073e9SAndroid Build Coastguard Worker # constructor. Using the factory constructor with TorchScript 14382*da0073e9SAndroid Build Coastguard Worker # TorchScript creates an anonymous `NamedTuple` class instead of 14383*da0073e9SAndroid Build Coastguard Worker # preserving the actual name. For example, the actual generated 14384*da0073e9SAndroid Build Coastguard Worker # signature in this case is: 14385*da0073e9SAndroid Build Coastguard Worker # graph(%x.1 : NamedTuple(x : Tensor, y : Tensor)) 14386*da0073e9SAndroid Build Coastguard Worker # It looks like similar test cases have had this issue as well 14387*da0073e9SAndroid Build Coastguard Worker # (see: `test_namedtuple_python`). 14388*da0073e9SAndroid Build Coastguard Worker FileCheck().check(r"NamedTuple(x : Tensor = 1, y : Tensor = 2))") \ 14389*da0073e9SAndroid Build Coastguard Worker .check_next(r"return (%x.1)") \ 14390*da0073e9SAndroid Build Coastguard Worker .run(fn.graph) 14391*da0073e9SAndroid Build Coastguard Worker 14392*da0073e9SAndroid Build Coastguard Worker def test_isinstance_dynamic(self): 14393*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14394*da0073e9SAndroid Build Coastguard Worker def foo(a): 14395*da0073e9SAndroid Build Coastguard Worker # type: (Optional[List[int]]) -> int 14396*da0073e9SAndroid Build Coastguard Worker b = 0 14397*da0073e9SAndroid Build Coastguard Worker if isinstance(a, (int, (float,), list, str)): 14398*da0073e9SAndroid Build Coastguard Worker b += 1 14399*da0073e9SAndroid Build Coastguard Worker if isinstance(a, (int, str)): 14400*da0073e9SAndroid Build Coastguard Worker b += 1 14401*da0073e9SAndroid Build Coastguard Worker if isinstance(a, List[int]): 14402*da0073e9SAndroid Build Coastguard Worker b += 1 
14403*da0073e9SAndroid Build Coastguard Worker return b 14404*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo([3, 4]), 2) 14405*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(None), 0) 14406*da0073e9SAndroid Build Coastguard Worker 14407*da0073e9SAndroid Build Coastguard Worker def test_function_overloads(self): 14408*da0073e9SAndroid Build Coastguard Worker # TODO: pyflakes currently does not compose @overload annotation with other 14409*da0073e9SAndroid Build Coastguard Worker # decorators. This is fixed on master but not on version 2.1.1. 14410*da0073e9SAndroid Build Coastguard Worker # Next version update remove noqa and add @typing.overload annotation 14411*da0073e9SAndroid Build Coastguard Worker 14412*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14413*da0073e9SAndroid Build Coastguard Worker def test_simple(x1): # noqa: F811 14414*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 14415*da0073e9SAndroid Build Coastguard Worker pass 14416*da0073e9SAndroid Build Coastguard Worker 14417*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14418*da0073e9SAndroid Build Coastguard Worker def test_simple(x1): # noqa: F811 14419*da0073e9SAndroid Build Coastguard Worker # type: (float) -> float 14420*da0073e9SAndroid Build Coastguard Worker pass 14421*da0073e9SAndroid Build Coastguard Worker 14422*da0073e9SAndroid Build Coastguard Worker def test_simple(x1): # noqa: F811 14423*da0073e9SAndroid Build Coastguard Worker return x1 14424*da0073e9SAndroid Build Coastguard Worker 14425*da0073e9SAndroid Build Coastguard Worker def invoke_function(): 14426*da0073e9SAndroid Build Coastguard Worker return test_simple(1.0), test_simple(.5) 14427*da0073e9SAndroid Build Coastguard Worker 14428*da0073e9SAndroid Build Coastguard Worker self.checkScript(invoke_function, ()) 14429*da0073e9SAndroid Build Coastguard Worker 14430*da0073e9SAndroid Build Coastguard Worker # testing that the 
functions are cached 14431*da0073e9SAndroid Build Coastguard Worker compiled_fns_1 = torch.jit._script._get_overloads(test_simple) 14432*da0073e9SAndroid Build Coastguard Worker compiled_fns_2 = torch.jit._script._get_overloads(test_simple) 14433*da0073e9SAndroid Build Coastguard Worker for a, b in zip(compiled_fns_1, compiled_fns_2): 14434*da0073e9SAndroid Build Coastguard Worker self.assertIs(a.graph, b.graph) 14435*da0073e9SAndroid Build Coastguard Worker 14436*da0073e9SAndroid Build Coastguard Worker old_func = test_simple 14437*da0073e9SAndroid Build Coastguard Worker 14438*da0073e9SAndroid Build Coastguard Worker # testing that new functions added work with caching 14439*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14440*da0073e9SAndroid Build Coastguard Worker def test_simple(x1): # noqa: F811 14441*da0073e9SAndroid Build Coastguard Worker # type: (str) -> str 14442*da0073e9SAndroid Build Coastguard Worker pass 14443*da0073e9SAndroid Build Coastguard Worker 14444*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14445*da0073e9SAndroid Build Coastguard Worker def my_func(): 14446*da0073e9SAndroid Build Coastguard Worker return old_func("hi") 14447*da0073e9SAndroid Build Coastguard Worker 14448*da0073e9SAndroid Build Coastguard Worker # testing new function same qualified name 14449*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14450*da0073e9SAndroid Build Coastguard Worker def test_simple(a, b): # noqa: F811 14451*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> int 14452*da0073e9SAndroid Build Coastguard Worker pass 14453*da0073e9SAndroid Build Coastguard Worker 14454*da0073e9SAndroid Build Coastguard Worker def test_simple(a, b): 14455*da0073e9SAndroid Build Coastguard Worker return a + b 14456*da0073e9SAndroid Build Coastguard Worker 14457*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14458*da0073e9SAndroid Build Coastguard Worker def fn(): 14459*da0073e9SAndroid 
Build Coastguard Worker return test_simple(3, 4) 14460*da0073e9SAndroid Build Coastguard Worker 14461*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fn(), 7) 14462*da0073e9SAndroid Build Coastguard Worker 14463*da0073e9SAndroid Build Coastguard Worker # currently we take the default values have to be specified in the 14464*da0073e9SAndroid Build Coastguard Worker # overload as well - TODO take them from implementation and apply 14465*da0073e9SAndroid Build Coastguard Worker # where the type is valid. 14466*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14467*da0073e9SAndroid Build Coastguard Worker def identity(x1): # noqa: F811 14468*da0073e9SAndroid Build Coastguard Worker # type: (str) -> str 14469*da0073e9SAndroid Build Coastguard Worker pass 14470*da0073e9SAndroid Build Coastguard Worker 14471*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14472*da0073e9SAndroid Build Coastguard Worker def identity(x1): # noqa: F811 14473*da0073e9SAndroid Build Coastguard Worker # type: (float) -> float 14474*da0073e9SAndroid Build Coastguard Worker pass 14475*da0073e9SAndroid Build Coastguard Worker 14476*da0073e9SAndroid Build Coastguard Worker def identity(x1=1.0): # noqa: F811 14477*da0073e9SAndroid Build Coastguard Worker return x1 14478*da0073e9SAndroid Build Coastguard Worker 14479*da0073e9SAndroid Build Coastguard Worker def invoke(): 14480*da0073e9SAndroid Build Coastguard Worker return identity(), identity(.5), identity("hi") 14481*da0073e9SAndroid Build Coastguard Worker 14482*da0073e9SAndroid Build Coastguard Worker self.checkScript(invoke, ()) 14483*da0073e9SAndroid Build Coastguard Worker 14484*da0073e9SAndroid Build Coastguard Worker def schema_match_failure(): 14485*da0073e9SAndroid Build Coastguard Worker return identity((1, 2)) 14486*da0073e9SAndroid Build Coastguard Worker 14487*da0073e9SAndroid Build Coastguard Worker thrown = False 14488*da0073e9SAndroid Build Coastguard Worker try: 
14489*da0073e9SAndroid Build Coastguard Worker torch.jit.script(schema_match_failure) 14490*da0073e9SAndroid Build Coastguard Worker except Exception as e: 14491*da0073e9SAndroid Build Coastguard Worker thrown = True 14492*da0073e9SAndroid Build Coastguard Worker self.assertTrue(r"of type 'str'" in str(e) and r"of type 'float" in str(e)) 14493*da0073e9SAndroid Build Coastguard Worker self.assertTrue(thrown) 14494*da0073e9SAndroid Build Coastguard Worker 14495*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "cannot be directly compiled"): 14496*da0073e9SAndroid Build Coastguard Worker torch.jit.script(identity) 14497*da0073e9SAndroid Build Coastguard Worker 14498*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14499*da0073e9SAndroid Build Coastguard Worker def impl_compile_failure(x, y): # noqa: F811 14500*da0073e9SAndroid Build Coastguard Worker # type: (str, str) -> (str) 14501*da0073e9SAndroid Build Coastguard Worker pass 14502*da0073e9SAndroid Build Coastguard Worker 14503*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14504*da0073e9SAndroid Build Coastguard Worker def impl_compile_failure(x, y): # noqa: F811 14505*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> (int) 14506*da0073e9SAndroid Build Coastguard Worker pass 14507*da0073e9SAndroid Build Coastguard Worker 14508*da0073e9SAndroid Build Coastguard Worker def impl_compile_failure(x, y): # noqa: F811 14509*da0073e9SAndroid Build Coastguard Worker return x - y 14510*da0073e9SAndroid Build Coastguard Worker 14511*da0073e9SAndroid Build Coastguard Worker def test(): 14512*da0073e9SAndroid Build Coastguard Worker impl_compile_failure("one", "two") 14513*da0073e9SAndroid Build Coastguard Worker 14514*da0073e9SAndroid Build Coastguard Worker 14515*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "Arguments for call are not valid"): 14516*da0073e9SAndroid Build Coastguard Worker 
torch.jit.script(test) 14517*da0073e9SAndroid Build Coastguard Worker 14518*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14519*da0073e9SAndroid Build Coastguard Worker def good_overload(x=1): # noqa: F811 14520*da0073e9SAndroid Build Coastguard Worker # type: (int) -> (int) 14521*da0073e9SAndroid Build Coastguard Worker pass 14522*da0073e9SAndroid Build Coastguard Worker 14523*da0073e9SAndroid Build Coastguard Worker def good_overload(x=1): # noqa: F811 14524*da0073e9SAndroid Build Coastguard Worker return x 14525*da0073e9SAndroid Build Coastguard Worker 14526*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14527*da0073e9SAndroid Build Coastguard Worker def foo(): 14528*da0073e9SAndroid Build Coastguard Worker return good_overload() 14529*da0073e9SAndroid Build Coastguard Worker 14530*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), 1) 14531*da0073e9SAndroid Build Coastguard Worker 14532*da0073e9SAndroid Build Coastguard Worker 14533*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "must equal to the default parameter"): 14534*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14535*da0073e9SAndroid Build Coastguard Worker def bad_default_on_overload(x, y=2): # noqa: F811 14536*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> (int) 14537*da0073e9SAndroid Build Coastguard Worker pass 14538*da0073e9SAndroid Build Coastguard Worker 14539*da0073e9SAndroid Build Coastguard Worker def bad_default_on_overload(x, y=1): # noqa: F811 14540*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> (int) 14541*da0073e9SAndroid Build Coastguard Worker pass 14542*da0073e9SAndroid Build Coastguard Worker 14543*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14544*da0073e9SAndroid Build Coastguard Worker def test(): 14545*da0073e9SAndroid Build Coastguard Worker return bad_default_on_overload(1, 2) 14546*da0073e9SAndroid Build Coastguard 
Worker 14547*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14548*da0073e9SAndroid Build Coastguard Worker def diff_default(x): # noqa: F811 14549*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 14550*da0073e9SAndroid Build Coastguard Worker pass 14551*da0073e9SAndroid Build Coastguard Worker 14552*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14553*da0073e9SAndroid Build Coastguard Worker def diff_default(x): # noqa: F811 14554*da0073e9SAndroid Build Coastguard Worker # type: (str) -> str 14555*da0073e9SAndroid Build Coastguard Worker pass 14556*da0073e9SAndroid Build Coastguard Worker 14557*da0073e9SAndroid Build Coastguard Worker def diff_default(x="hi"): # noqa: F811 14558*da0073e9SAndroid Build Coastguard Worker return x 14559*da0073e9SAndroid Build Coastguard Worker 14560*da0073e9SAndroid Build Coastguard Worker def test(): 14561*da0073e9SAndroid Build Coastguard Worker return diff_default(), diff_default(2), diff_default("abc") 14562*da0073e9SAndroid Build Coastguard Worker 14563*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test(), torch.jit.script(test)()) 14564*da0073e9SAndroid Build Coastguard Worker 14565*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14566*da0073e9SAndroid Build Coastguard Worker def diff_num_params(x): # noqa: F811 14567*da0073e9SAndroid Build Coastguard Worker # type: (float) -> float 14568*da0073e9SAndroid Build Coastguard Worker pass 14569*da0073e9SAndroid Build Coastguard Worker 14570*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14571*da0073e9SAndroid Build Coastguard Worker def diff_num_params(x, y): # noqa: F811 14572*da0073e9SAndroid Build Coastguard Worker # type: (int, int) -> int 14573*da0073e9SAndroid Build Coastguard Worker pass 14574*da0073e9SAndroid Build Coastguard Worker 14575*da0073e9SAndroid Build Coastguard Worker def diff_num_params(x, y=2, z=3): # noqa: F811 
14576*da0073e9SAndroid Build Coastguard Worker # type: (Union[float, int], int, int) 14577*da0073e9SAndroid Build Coastguard Worker return x + y + z 14578*da0073e9SAndroid Build Coastguard Worker 14579*da0073e9SAndroid Build Coastguard Worker def test(): 14580*da0073e9SAndroid Build Coastguard Worker return diff_num_params(1.0), diff_num_params(1, 2), diff_num_params(1), diff_num_params(1, 2, 3) 14581*da0073e9SAndroid Build Coastguard Worker 14582*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test(), torch.jit.script(test)()) 14583*da0073e9SAndroid Build Coastguard Worker 14584*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14585*da0073e9SAndroid Build Coastguard Worker def diff_num_params_no_annot(): 14586*da0073e9SAndroid Build Coastguard Worker # type: () -> int 14587*da0073e9SAndroid Build Coastguard Worker pass 14588*da0073e9SAndroid Build Coastguard Worker 14589*da0073e9SAndroid Build Coastguard Worker def diff_num_params_no_annot(x=1): # noqa: F811 14590*da0073e9SAndroid Build Coastguard Worker return x 14591*da0073e9SAndroid Build Coastguard Worker 14592*da0073e9SAndroid Build Coastguard Worker def test(): 14593*da0073e9SAndroid Build Coastguard Worker return diff_num_params_no_annot(1.0) 14594*da0073e9SAndroid Build Coastguard Worker 14595*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "Parameters not specified"): 14596*da0073e9SAndroid Build Coastguard Worker torch.jit.script(test) 14597*da0073e9SAndroid Build Coastguard Worker 14598*da0073e9SAndroid Build Coastguard Worker def test_function_overload_misuse(self): 14599*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"): 14600*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload 14601*da0073e9SAndroid Build Coastguard Worker def wrong_decl_body(x: str) -> str: 14602*da0073e9SAndroid Build Coastguard Worker return x + "0" 14603*da0073e9SAndroid 
Build Coastguard Worker 14604*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Only `pass` statement or `...` can be the body"): 14605*da0073e9SAndroid Build Coastguard Worker class MyClass: 14606*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method 14607*da0073e9SAndroid Build Coastguard Worker def method(self): 14608*da0073e9SAndroid Build Coastguard Worker return 0 14609*da0073e9SAndroid Build Coastguard Worker 14610*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload 14611*da0073e9SAndroid Build Coastguard Worker def null_overload(x: int) -> int: ... # noqa: E704 14612*da0073e9SAndroid Build Coastguard Worker 14613*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14614*da0073e9SAndroid Build Coastguard Worker def null_overload(x: str) -> str: # noqa: F811 14615*da0073e9SAndroid Build Coastguard Worker pass 14616*da0073e9SAndroid Build Coastguard Worker 14617*da0073e9SAndroid Build Coastguard Worker def null_overload_driver(): 14618*da0073e9SAndroid Build Coastguard Worker return null_overload(0) 14619*da0073e9SAndroid Build Coastguard Worker 14620*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Implementation for the function ".+" is missing.'): 14621*da0073e9SAndroid Build Coastguard Worker torch.jit.script(null_overload_driver) 14622*da0073e9SAndroid Build Coastguard Worker 14623*da0073e9SAndroid Build Coastguard Worker class OverloadMisuse(torch.nn.Module): 14624*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method 14625*da0073e9SAndroid Build Coastguard Worker def forward(self, x: int): 14626*da0073e9SAndroid Build Coastguard Worker pass 14627*da0073e9SAndroid Build Coastguard Worker 14628*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 14629*da0073e9SAndroid Build Coastguard Worker def forward(self, x: Tensor): # noqa: F811 14630*da0073e9SAndroid Build Coastguard Worker pass 
14631*da0073e9SAndroid Build Coastguard Worker 14632*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'Implementation for the method ".+" is missing.'): 14633*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(OverloadMisuse()) 14634*da0073e9SAndroid Build Coastguard Worker 14635*da0073e9SAndroid Build Coastguard Worker 14636*da0073e9SAndroid Build Coastguard Worker def test_script_method_torch_function_overload(self): 14637*da0073e9SAndroid Build Coastguard Worker class MyCustomTensor(torch.Tensor): 14638*da0073e9SAndroid Build Coastguard Worker pass 14639*da0073e9SAndroid Build Coastguard Worker 14640*da0073e9SAndroid Build Coastguard Worker class MyCustomModule(torch.nn.Module): 14641*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 14642*da0073e9SAndroid Build Coastguard Worker return torch.relu(x) 14643*da0073e9SAndroid Build Coastguard Worker 14644*da0073e9SAndroid Build Coastguard Worker scripted_mod = torch.jit.script(MyCustomModule()) 14645*da0073e9SAndroid Build Coastguard Worker t = torch.tensor([3.0]) 14646*da0073e9SAndroid Build Coastguard Worker ref_out = scripted_mod(t) 14647*da0073e9SAndroid Build Coastguard Worker 14648*da0073e9SAndroid Build Coastguard Worker t_custom = MyCustomTensor([3.0]) 14649*da0073e9SAndroid Build Coastguard Worker out1 = scripted_mod(t_custom) 14650*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out1, ref_out) 14651*da0073e9SAndroid Build Coastguard Worker 14652*da0073e9SAndroid Build Coastguard Worker out2 = scripted_mod.forward(t_custom) 14653*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out2, ref_out) 14654*da0073e9SAndroid Build Coastguard Worker 14655*da0073e9SAndroid Build Coastguard Worker def test_function_overloading_isinstance(self): 14656*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14657*da0073e9SAndroid Build Coastguard Worker def my_conv(x, y): # noqa: F811 14658*da0073e9SAndroid Build Coastguard 
Worker # type: (float, str) -> (float) 14659*da0073e9SAndroid Build Coastguard Worker pass 14660*da0073e9SAndroid Build Coastguard Worker 14661*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload # noqa: F811 14662*da0073e9SAndroid Build Coastguard Worker def my_conv(x, y): # noqa: F811 14663*da0073e9SAndroid Build Coastguard Worker # type: (float, float) -> (float) 14664*da0073e9SAndroid Build Coastguard Worker pass 14665*da0073e9SAndroid Build Coastguard Worker 14666*da0073e9SAndroid Build Coastguard Worker def my_conv(x, y=2.0): # noqa: F811 14667*da0073e9SAndroid Build Coastguard Worker if isinstance(y, str): 14668*da0073e9SAndroid Build Coastguard Worker if y == "hi": 14669*da0073e9SAndroid Build Coastguard Worker return 4.0 - x 14670*da0073e9SAndroid Build Coastguard Worker else: 14671*da0073e9SAndroid Build Coastguard Worker return 5.0 - x 14672*da0073e9SAndroid Build Coastguard Worker else: 14673*da0073e9SAndroid Build Coastguard Worker return 2.0 + x 14674*da0073e9SAndroid Build Coastguard Worker 14675*da0073e9SAndroid Build Coastguard Worker def test_uses(): 14676*da0073e9SAndroid Build Coastguard Worker return my_conv(1.5), my_conv(1.5, "hi"), my_conv(1.5, 5.0) 14677*da0073e9SAndroid Build Coastguard Worker 14678*da0073e9SAndroid Build Coastguard Worker self.checkScript(test_uses, ()) 14679*da0073e9SAndroid Build Coastguard Worker 14680*da0073e9SAndroid Build Coastguard Worker def test_method_overloading(self): 14681*da0073e9SAndroid Build Coastguard Worker class Over(torch.nn.Module): 14682*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 14683*da0073e9SAndroid Build Coastguard Worker def forward(self, x): # noqa: F811 14684*da0073e9SAndroid Build Coastguard Worker # type: (Tuple[Tensor, Tensor]) -> Tensor 14685*da0073e9SAndroid Build Coastguard Worker pass 14686*da0073e9SAndroid Build Coastguard Worker 14687*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 
14688*da0073e9SAndroid Build Coastguard Worker def forward(self, x): # noqa: F811 14689*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 14690*da0073e9SAndroid Build Coastguard Worker pass 14691*da0073e9SAndroid Build Coastguard Worker 14692*da0073e9SAndroid Build Coastguard Worker def forward(self, x): # noqa: F811 14693*da0073e9SAndroid Build Coastguard Worker if isinstance(x, Tensor): 14694*da0073e9SAndroid Build Coastguard Worker return x + 20 14695*da0073e9SAndroid Build Coastguard Worker else: 14696*da0073e9SAndroid Build Coastguard Worker return x[0] + 5 14697*da0073e9SAndroid Build Coastguard Worker 14698*da0073e9SAndroid Build Coastguard Worker class S(torch.jit.ScriptModule): 14699*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 14700*da0073e9SAndroid Build Coastguard Worker super().__init__() 14701*da0073e9SAndroid Build Coastguard Worker self.weak = Over() 14702*da0073e9SAndroid Build Coastguard Worker 14703*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 14704*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 14705*da0073e9SAndroid Build Coastguard Worker return self.weak(x) + self.weak((x, x)) 14706*da0073e9SAndroid Build Coastguard Worker 14707*da0073e9SAndroid Build Coastguard Worker s_mod = S() 14708*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1) 14709*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s_mod(x), x + 20 + 5 + x) 14710*da0073e9SAndroid Build Coastguard Worker 14711*da0073e9SAndroid Build Coastguard Worker over = Over() 14712*da0073e9SAndroid Build Coastguard Worker self.assertEqual(over((x, x)), x + 5) 14713*da0073e9SAndroid Build Coastguard Worker self.assertEqual(over(x), x + 20) 14714*da0073e9SAndroid Build Coastguard Worker 14715*da0073e9SAndroid Build Coastguard Worker class Unannotated(torch.nn.Module): 14716*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 14717*da0073e9SAndroid Build Coastguard Worker def 
hello(self, x): # noqa: F811 14718*da0073e9SAndroid Build Coastguard Worker pass 14719*da0073e9SAndroid Build Coastguard Worker 14720*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 14721*da0073e9SAndroid Build Coastguard Worker def hello(self, x): # noqa: F811 14722*da0073e9SAndroid Build Coastguard Worker # type: (int) -> (int) 14723*da0073e9SAndroid Build Coastguard Worker pass 14724*da0073e9SAndroid Build Coastguard Worker 14725*da0073e9SAndroid Build Coastguard Worker def hello(self, x): # noqa: F811 14726*da0073e9SAndroid Build Coastguard Worker return x + 3 14727*da0073e9SAndroid Build Coastguard Worker 14728*da0073e9SAndroid Build Coastguard Worker def forward(self): 14729*da0073e9SAndroid Build Coastguard Worker return self.hello(1), self.hello(.5) 14730*da0073e9SAndroid Build Coastguard Worker 14731*da0073e9SAndroid Build Coastguard Worker w = Unannotated() 14732*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"): 14733*da0073e9SAndroid Build Coastguard Worker torch.jit.script(w) 14734*da0073e9SAndroid Build Coastguard Worker 14735*da0073e9SAndroid Build Coastguard Worker class CompileOverloadError(torch.nn.Module): 14736*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 14737*da0073e9SAndroid Build Coastguard Worker def hello(self, x): # noqa: F811 14738*da0073e9SAndroid Build Coastguard Worker # type: (str) -> (int) 14739*da0073e9SAndroid Build Coastguard Worker pass 14740*da0073e9SAndroid Build Coastguard Worker 14741*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 14742*da0073e9SAndroid Build Coastguard Worker def hello(self, x): # noqa: F811 14743*da0073e9SAndroid Build Coastguard Worker # type: (int) -> (int) 14744*da0073e9SAndroid Build Coastguard Worker pass 14745*da0073e9SAndroid Build Coastguard Worker 14746*da0073e9SAndroid Build Coastguard Worker def 
hello(self, x): # noqa: F811 14747*da0073e9SAndroid Build Coastguard Worker return x + 1 14748*da0073e9SAndroid Build Coastguard Worker 14749*da0073e9SAndroid Build Coastguard Worker def forward(self): 14750*da0073e9SAndroid Build Coastguard Worker return self.hello("hi"), self.hello(.5) 14751*da0073e9SAndroid Build Coastguard Worker 14752*da0073e9SAndroid Build Coastguard Worker w = CompileOverloadError() 14753*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "but instead found type 'str'"): 14754*da0073e9SAndroid Build Coastguard Worker torch.jit.script(w) 14755*da0073e9SAndroid Build Coastguard Worker 14756*da0073e9SAndroid Build Coastguard Worker # testing overload declared first, then non-overload 14757*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"): 14758*da0073e9SAndroid Build Coastguard Worker class W3(torch.nn.Module): 14759*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 14760*da0073e9SAndroid Build Coastguard Worker def forward(self, x): # noqa: F811 14761*da0073e9SAndroid Build Coastguard Worker # type: (int) -> int 14762*da0073e9SAndroid Build Coastguard Worker pass 14763*da0073e9SAndroid Build Coastguard Worker 14764*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 14765*da0073e9SAndroid Build Coastguard Worker def forward(self, x): # noqa: F811 14766*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tensor 14767*da0073e9SAndroid Build Coastguard Worker pass 14768*da0073e9SAndroid Build Coastguard Worker 14769*da0073e9SAndroid Build Coastguard Worker def forward(self, x): # noqa: F811 14770*da0073e9SAndroid Build Coastguard Worker return x + 5 14771*da0073e9SAndroid Build Coastguard Worker 14772*da0073e9SAndroid Build Coastguard Worker a = W3() 14773*da0073e9SAndroid Build Coastguard Worker b = torch.jit.script(a) 14774*da0073e9SAndroid Build Coastguard Worker 
14775*da0073e9SAndroid Build Coastguard Worker class W3(torch.nn.Module): 14776*da0073e9SAndroid Build Coastguard Worker def forward(self, x): # noqa: F811 14777*da0073e9SAndroid Build Coastguard Worker return x + 5 + 10 14778*da0073e9SAndroid Build Coastguard Worker 14779*da0073e9SAndroid Build Coastguard Worker a = W3() 14780*da0073e9SAndroid Build Coastguard Worker b = torch.jit.script(a) 14781*da0073e9SAndroid Build Coastguard Worker 14782*da0073e9SAndroid Build Coastguard Worker # testing non-overload declared first, then overload 14783*da0073e9SAndroid Build Coastguard Worker class W2(torch.nn.Module): 14784*da0073e9SAndroid Build Coastguard Worker def hello(self, x1, x2): 14785*da0073e9SAndroid Build Coastguard Worker return x1 + x2 14786*da0073e9SAndroid Build Coastguard Worker 14787*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 14788*da0073e9SAndroid Build Coastguard Worker return self.hello(x, x) 14789*da0073e9SAndroid Build Coastguard Worker 14790*da0073e9SAndroid Build Coastguard Worker a = torch.jit.script(W2()) 14791*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a(torch.tensor(1)), torch.tensor(2)) 14792*da0073e9SAndroid Build Coastguard Worker 14793*da0073e9SAndroid Build Coastguard Worker class W2(torch.nn.Module): 14794*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 14795*da0073e9SAndroid Build Coastguard Worker def hello(self, x): # noqa: F811 14796*da0073e9SAndroid Build Coastguard Worker pass 14797*da0073e9SAndroid Build Coastguard Worker 14798*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method # noqa: F811 14799*da0073e9SAndroid Build Coastguard Worker def hello(self, x): # noqa: F811 14800*da0073e9SAndroid Build Coastguard Worker # type: (int) -> (int) 14801*da0073e9SAndroid Build Coastguard Worker pass 14802*da0073e9SAndroid Build Coastguard Worker 14803*da0073e9SAndroid Build Coastguard Worker def hello(self, x): # noqa: F811 14804*da0073e9SAndroid Build 
Coastguard Worker return x + 5 + 10 14805*da0073e9SAndroid Build Coastguard Worker 14806*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 14807*da0073e9SAndroid Build Coastguard Worker return self.hello(1), self.hello(x) 14808*da0073e9SAndroid Build Coastguard Worker 14809*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"): 14810*da0073e9SAndroid Build Coastguard Worker a = torch.jit.script(W2()) 14811*da0073e9SAndroid Build Coastguard Worker 14812*da0073e9SAndroid Build Coastguard Worker def test_narrow_copy(self): 14813*da0073e9SAndroid Build Coastguard Worker def foo(a): 14814*da0073e9SAndroid Build Coastguard Worker return a.narrow_copy(0, 0, 5) 14815*da0073e9SAndroid Build Coastguard Worker 14816*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, [torch.rand(10)]) 14817*da0073e9SAndroid Build Coastguard Worker 14818*da0073e9SAndroid Build Coastguard Worker def test_select_after_chunk(self): 14819*da0073e9SAndroid Build Coastguard Worker def foo(x): 14820*da0073e9SAndroid Build Coastguard Worker chunked = torch.chunk(x, 1) 14821*da0073e9SAndroid Build Coastguard Worker foo = chunked[0] 14822*da0073e9SAndroid Build Coastguard Worker foo.add_(5) 14823*da0073e9SAndroid Build Coastguard Worker return x 14824*da0073e9SAndroid Build Coastguard Worker 14825*da0073e9SAndroid Build Coastguard Worker self.checkScript(foo, [torch.rand(2, 3)]) 14826*da0073e9SAndroid Build Coastguard Worker 14827*da0073e9SAndroid Build Coastguard Worker def test_nn_LSTM_with_layers(self): 14828*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 14829*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 14830*da0073e9SAndroid Build Coastguard Worker super().__init__() 14831*da0073e9SAndroid Build Coastguard Worker self.rnn = nn.LSTM(2, 3, 2, dropout=0) 14832*da0073e9SAndroid Build Coastguard Worker 14833*da0073e9SAndroid Build Coastguard Worker 
@torch.jit.script_method 14834*da0073e9SAndroid Build Coastguard Worker def forward(self, x, lengths, h0, c0): 14835*da0073e9SAndroid Build Coastguard Worker return self.rnn(x, (h0, c0))[0] 14836*da0073e9SAndroid Build Coastguard Worker 14837*da0073e9SAndroid Build Coastguard Worker class Eager(torch.nn.Module): 14838*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 14839*da0073e9SAndroid Build Coastguard Worker super().__init__() 14840*da0073e9SAndroid Build Coastguard Worker self.rnn = nn.LSTM(2, 3, 2, dropout=0) 14841*da0073e9SAndroid Build Coastguard Worker 14842*da0073e9SAndroid Build Coastguard Worker def forward(self, x, lengths, h0, c0): 14843*da0073e9SAndroid Build Coastguard Worker return self.rnn(x, (h0, c0))[0] 14844*da0073e9SAndroid Build Coastguard Worker 14845*da0073e9SAndroid Build Coastguard Worker inputs = (torch.randn(1, 1, 2), torch.LongTensor([7]), torch.randn(2, 1, 3), torch.randn(2, 1, 3)) 14846*da0073e9SAndroid Build Coastguard Worker eager_out = self.runAndSaveRNG(lambda: Eager()(*inputs), ())[0] 14847*da0073e9SAndroid Build Coastguard Worker script_out = self.runAndSaveRNG(lambda: M()(*inputs), ())[0] 14848*da0073e9SAndroid Build Coastguard Worker 14849*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_out, script_out) 14850*da0073e9SAndroid Build Coastguard Worker 14851*da0073e9SAndroid Build Coastguard Worker def test_nn_LSTM(self): 14852*da0073e9SAndroid Build Coastguard Worker input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)]) 14853*da0073e9SAndroid Build Coastguard Worker 14854*da0073e9SAndroid Build Coastguard Worker class S(torch.jit.ScriptModule): 14855*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 14856*da0073e9SAndroid Build Coastguard Worker super().__init__() 14857*da0073e9SAndroid Build Coastguard Worker self.x = torch.nn.LSTM(5, 5) 14858*da0073e9SAndroid Build Coastguard Worker 14859*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 
14860*da0073e9SAndroid Build Coastguard Worker def forward(self, input: PackedSequence) -> Tuple[PackedSequence, Tuple[torch.Tensor, torch.Tensor]]: 14861*da0073e9SAndroid Build Coastguard Worker return self.x(input) 14862*da0073e9SAndroid Build Coastguard Worker 14863*da0073e9SAndroid Build Coastguard Worker eager_out = self.runAndSaveRNG(lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0] 14864*da0073e9SAndroid Build Coastguard Worker script_out = self.runAndSaveRNG(lambda x: S()(x), (input,))[0] 14865*da0073e9SAndroid Build Coastguard Worker 14866*da0073e9SAndroid Build Coastguard Worker self.assertEqual(eager_out, script_out) 14867*da0073e9SAndroid Build Coastguard Worker 14868*da0073e9SAndroid Build Coastguard Worker def test_nn_GRU(self): 14869*da0073e9SAndroid Build Coastguard Worker seq_input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)]) 14870*da0073e9SAndroid Build Coastguard Worker tensor_input = torch.randn(5, 5, 5) 14871*da0073e9SAndroid Build Coastguard Worker 14872*da0073e9SAndroid Build Coastguard Worker class SeqLengthGRU(torch.jit.ScriptModule): 14873*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 14874*da0073e9SAndroid Build Coastguard Worker super().__init__() 14875*da0073e9SAndroid Build Coastguard Worker self.x = torch.nn.GRU(5, 5) 14876*da0073e9SAndroid Build Coastguard Worker 14877*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 14878*da0073e9SAndroid Build Coastguard Worker def forward(self, input: PackedSequence) -> Tuple[PackedSequence, torch.Tensor]: 14879*da0073e9SAndroid Build Coastguard Worker return self.x(input) 14880*da0073e9SAndroid Build Coastguard Worker 14881*da0073e9SAndroid Build Coastguard Worker class TensorGRU(torch.jit.ScriptModule): 14882*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 14883*da0073e9SAndroid Build Coastguard Worker super().__init__() 14884*da0073e9SAndroid Build Coastguard Worker self.x = torch.nn.GRU(5, 5) 14885*da0073e9SAndroid Build 
Coastguard Worker 14886*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 14887*da0073e9SAndroid Build Coastguard Worker def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 14888*da0073e9SAndroid Build Coastguard Worker return self.x(input) 14889*da0073e9SAndroid Build Coastguard Worker 14890*da0073e9SAndroid Build Coastguard Worker seq_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (seq_input,))[0] 14891*da0073e9SAndroid Build Coastguard Worker seq_script_out = self.runAndSaveRNG(lambda x: SeqLengthGRU()(x), (seq_input,))[0] 14892*da0073e9SAndroid Build Coastguard Worker tensor_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (tensor_input,))[0] 14893*da0073e9SAndroid Build Coastguard Worker tensor_script_out = self.runAndSaveRNG(lambda x: TensorGRU()(x), (tensor_input,))[0] 14894*da0073e9SAndroid Build Coastguard Worker 14895*da0073e9SAndroid Build Coastguard Worker self.assertEqual(seq_eager_out, seq_script_out) 14896*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tensor_eager_out, tensor_script_out) 14897*da0073e9SAndroid Build Coastguard Worker 14898*da0073e9SAndroid Build Coastguard Worker def test_torchscript_memoryformat(self): 14899*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14900*da0073e9SAndroid Build Coastguard Worker def fn(x): 14901*da0073e9SAndroid Build Coastguard Worker return x.contiguous(memory_format=torch.channels_last) 14902*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 3, 6, 6) 14903*da0073e9SAndroid Build Coastguard Worker y = fn(x) 14904*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.is_contiguous(memory_format=torch.channels_last)) 14905*da0073e9SAndroid Build Coastguard Worker 14906*da0073e9SAndroid Build Coastguard Worker def test_torchscript_multi_head_attn(self): 14907*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 14908*da0073e9SAndroid Build Coastguard Worker def jit_multihead_attn_forward(query, # 
type: Tensor 14909*da0073e9SAndroid Build Coastguard Worker key, # type: Tensor 14910*da0073e9SAndroid Build Coastguard Worker value, # type: Tensor 14911*da0073e9SAndroid Build Coastguard Worker embed_dim_to_check, # type: int 14912*da0073e9SAndroid Build Coastguard Worker num_heads, # type: int 14913*da0073e9SAndroid Build Coastguard Worker in_proj_weight, # type: Tensor 14914*da0073e9SAndroid Build Coastguard Worker in_proj_bias, # type: Tensor 14915*da0073e9SAndroid Build Coastguard Worker bias_k, # type: Optional[Tensor] 14916*da0073e9SAndroid Build Coastguard Worker bias_v, # type: Optional[Tensor] 14917*da0073e9SAndroid Build Coastguard Worker add_zero_attn, # type: bool 14918*da0073e9SAndroid Build Coastguard Worker dropout, # type: float 14919*da0073e9SAndroid Build Coastguard Worker out_proj_weight, # type: Tensor 14920*da0073e9SAndroid Build Coastguard Worker out_proj_bias, # type: Tensor 14921*da0073e9SAndroid Build Coastguard Worker training=True, # type: bool 14922*da0073e9SAndroid Build Coastguard Worker key_padding_mask=None, # type: Optional[Tensor] 14923*da0073e9SAndroid Build Coastguard Worker need_weights=True, # type: bool 14924*da0073e9SAndroid Build Coastguard Worker attn_mask=None # type: Optional[Tensor] 14925*da0073e9SAndroid Build Coastguard Worker ): 14926*da0073e9SAndroid Build Coastguard Worker # type: (...) 
-> Tuple[Tensor, Optional[Tensor]] 14927*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.multi_head_attention_forward(query, key, value, 14928*da0073e9SAndroid Build Coastguard Worker embed_dim_to_check, num_heads, 14929*da0073e9SAndroid Build Coastguard Worker in_proj_weight, in_proj_bias, 14930*da0073e9SAndroid Build Coastguard Worker bias_k, bias_v, 14931*da0073e9SAndroid Build Coastguard Worker add_zero_attn, dropout, 14932*da0073e9SAndroid Build Coastguard Worker out_proj_weight, out_proj_bias, 14933*da0073e9SAndroid Build Coastguard Worker training, key_padding_mask, 14934*da0073e9SAndroid Build Coastguard Worker need_weights, attn_mask) 14935*da0073e9SAndroid Build Coastguard Worker 14936*da0073e9SAndroid Build Coastguard Worker src_l = 3 14937*da0073e9SAndroid Build Coastguard Worker bsz = 5 14938*da0073e9SAndroid Build Coastguard Worker embed_size = 8 14939*da0073e9SAndroid Build Coastguard Worker nhead = 2 14940*da0073e9SAndroid Build Coastguard Worker multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead) 14941*da0073e9SAndroid Build Coastguard Worker query = torch.rand((src_l, bsz, embed_size)) 14942*da0073e9SAndroid Build Coastguard Worker key = torch.rand((src_l, bsz, embed_size)) 14943*da0073e9SAndroid Build Coastguard Worker value = torch.rand((src_l, bsz, embed_size)) 14944*da0073e9SAndroid Build Coastguard Worker 14945*da0073e9SAndroid Build Coastguard Worker mask = (torch.triu(torch.ones(src_l, src_l)) == 1).transpose(0, 1) 14946*da0073e9SAndroid Build Coastguard Worker mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, 0.0).to(torch.get_default_dtype()) 14947*da0073e9SAndroid Build Coastguard Worker 14948*da0073e9SAndroid Build Coastguard Worker jit_out = jit_multihead_attn_forward(query, key, value, 14949*da0073e9SAndroid Build Coastguard Worker embed_size, nhead, 14950*da0073e9SAndroid Build Coastguard Worker multi_head_attn.in_proj_weight, 14951*da0073e9SAndroid Build Coastguard 
Worker multi_head_attn.in_proj_bias, 14952*da0073e9SAndroid Build Coastguard Worker multi_head_attn.bias_k, multi_head_attn.bias_v, 14953*da0073e9SAndroid Build Coastguard Worker multi_head_attn.add_zero_attn, multi_head_attn.dropout, 14954*da0073e9SAndroid Build Coastguard Worker multi_head_attn.out_proj.weight, 14955*da0073e9SAndroid Build Coastguard Worker multi_head_attn.out_proj.bias, attn_mask=mask)[0] 14956*da0073e9SAndroid Build Coastguard Worker 14957*da0073e9SAndroid Build Coastguard Worker py_out = torch.nn.functional.multi_head_attention_forward(query, key, value, 14958*da0073e9SAndroid Build Coastguard Worker embed_size, nhead, 14959*da0073e9SAndroid Build Coastguard Worker multi_head_attn.in_proj_weight, 14960*da0073e9SAndroid Build Coastguard Worker multi_head_attn.in_proj_bias, 14961*da0073e9SAndroid Build Coastguard Worker multi_head_attn.bias_k, 14962*da0073e9SAndroid Build Coastguard Worker multi_head_attn.bias_v, 14963*da0073e9SAndroid Build Coastguard Worker multi_head_attn.add_zero_attn, 14964*da0073e9SAndroid Build Coastguard Worker multi_head_attn.dropout, 14965*da0073e9SAndroid Build Coastguard Worker multi_head_attn.out_proj.weight, 14966*da0073e9SAndroid Build Coastguard Worker multi_head_attn.out_proj.bias, 14967*da0073e9SAndroid Build Coastguard Worker attn_mask=mask)[0] 14968*da0073e9SAndroid Build Coastguard Worker # print("rel. 
error: ") 14969*da0073e9SAndroid Build Coastguard Worker # print(jit_out / py_out - 1) 14970*da0073e9SAndroid Build Coastguard Worker self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) 14971*da0073e9SAndroid Build Coastguard Worker 14972*da0073e9SAndroid Build Coastguard Worker def test_torchscript_multi_head_attn_fast_path(self): 14973*da0073e9SAndroid Build Coastguard Worker src_l = 3 14974*da0073e9SAndroid Build Coastguard Worker bsz = 5 14975*da0073e9SAndroid Build Coastguard Worker embed_size = 8 14976*da0073e9SAndroid Build Coastguard Worker nhead = 2 14977*da0073e9SAndroid Build Coastguard Worker multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True) 14978*da0073e9SAndroid Build Coastguard Worker multi_head_attn = multi_head_attn.eval() 14979*da0073e9SAndroid Build Coastguard Worker 14980*da0073e9SAndroid Build Coastguard Worker query = key = value = torch.rand((bsz, src_l, embed_size)) 14981*da0073e9SAndroid Build Coastguard Worker 14982*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 14983*da0073e9SAndroid Build Coastguard Worker py_out = multi_head_attn(query, key, value) 14984*da0073e9SAndroid Build Coastguard Worker mha = torch.jit.script(multi_head_attn) 14985*da0073e9SAndroid Build Coastguard Worker jit_out = mha(query, key, value) 14986*da0073e9SAndroid Build Coastguard Worker torch.testing.assert_close(jit_out, py_out) 14987*da0073e9SAndroid Build Coastguard Worker 14988*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "no CUDA") 14989*da0073e9SAndroid Build Coastguard Worker def test_scriptmodule_multi_head_attn_cuda(self): 14990*da0073e9SAndroid Build Coastguard Worker 14991*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.jit.ScriptModule): 14992*da0073e9SAndroid Build Coastguard Worker def __init__(self, embed_dim, num_heads): 14993*da0073e9SAndroid Build Coastguard Worker super().__init__() 14994*da0073e9SAndroid Build Coastguard Worker sample_q = torch.randn(3, 
2, embed_dim) 14995*da0073e9SAndroid Build Coastguard Worker sample_kv = torch.randn(3, 2, embed_dim) 14996*da0073e9SAndroid Build Coastguard Worker attention = nn.MultiheadAttention(embed_dim, num_heads) 14997*da0073e9SAndroid Build Coastguard Worker attention.eval() 14998*da0073e9SAndroid Build Coastguard Worker 14999*da0073e9SAndroid Build Coastguard Worker self.mod = torch.jit.trace(attention, 15000*da0073e9SAndroid Build Coastguard Worker (sample_q, sample_kv, sample_kv)) 15001*da0073e9SAndroid Build Coastguard Worker 15002*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15003*da0073e9SAndroid Build Coastguard Worker def forward(self, q, k, v): 15004*da0073e9SAndroid Build Coastguard Worker return self.mod(q, k, v) 15005*da0073e9SAndroid Build Coastguard Worker 15006*da0073e9SAndroid Build Coastguard Worker embed_dim = 8 15007*da0073e9SAndroid Build Coastguard Worker num_heads = 2 15008*da0073e9SAndroid Build Coastguard Worker sl = 3 15009*da0073e9SAndroid Build Coastguard Worker bs = 2 15010*da0073e9SAndroid Build Coastguard Worker model = MyModule(embed_dim, num_heads).cuda() 15011*da0073e9SAndroid Build Coastguard Worker q = torch.randn(sl, bs, embed_dim, device="cuda") 15012*da0073e9SAndroid Build Coastguard Worker kv = torch.randn(sl, bs, embed_dim, device="cuda") 15013*da0073e9SAndroid Build Coastguard Worker 15014*da0073e9SAndroid Build Coastguard Worker jit_out = model(q, kv, kv)[0] 15015*da0073e9SAndroid Build Coastguard Worker py_out = torch.nn.functional.multi_head_attention_forward(q, kv, kv, 15016*da0073e9SAndroid Build Coastguard Worker embed_dim, num_heads, 15017*da0073e9SAndroid Build Coastguard Worker model.mod.in_proj_weight, 15018*da0073e9SAndroid Build Coastguard Worker model.mod.in_proj_bias, 15019*da0073e9SAndroid Build Coastguard Worker None, None, None, 0.0, 15020*da0073e9SAndroid Build Coastguard Worker model.mod.out_proj.weight, 15021*da0073e9SAndroid Build Coastguard Worker model.mod.out_proj.bias)[0] 
15022*da0073e9SAndroid Build Coastguard Worker self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) 15023*da0073e9SAndroid Build Coastguard Worker 15024*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "no CUDA") 15025*da0073e9SAndroid Build Coastguard Worker def test_scriptmodule_transformer_cuda(self): 15026*da0073e9SAndroid Build Coastguard Worker 15027*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.jit.ScriptModule): 15028*da0073e9SAndroid Build Coastguard Worker def __init__(self, transformer, sample_q, sample_kv): 15029*da0073e9SAndroid Build Coastguard Worker super().__init__() 15030*da0073e9SAndroid Build Coastguard Worker transformer.eval() 15031*da0073e9SAndroid Build Coastguard Worker 15032*da0073e9SAndroid Build Coastguard Worker self.mod = torch.jit.trace(transformer, 15033*da0073e9SAndroid Build Coastguard Worker (sample_q, sample_kv)) 15034*da0073e9SAndroid Build Coastguard Worker 15035*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15036*da0073e9SAndroid Build Coastguard Worker def forward(self, q, k): 15037*da0073e9SAndroid Build Coastguard Worker return self.mod(q, k) 15038*da0073e9SAndroid Build Coastguard Worker 15039*da0073e9SAndroid Build Coastguard Worker d_model = 8 15040*da0073e9SAndroid Build Coastguard Worker nhead = 2 15041*da0073e9SAndroid Build Coastguard Worker num_encoder_layers = 2 15042*da0073e9SAndroid Build Coastguard Worker num_decoder_layers = 2 15043*da0073e9SAndroid Build Coastguard Worker dim_feedforward = 16 15044*da0073e9SAndroid Build Coastguard Worker bsz = 2 15045*da0073e9SAndroid Build Coastguard Worker seq_length = 5 15046*da0073e9SAndroid Build Coastguard Worker tgt_length = 3 15047*da0073e9SAndroid Build Coastguard Worker 15048*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 15049*da0073e9SAndroid Build Coastguard Worker src = torch.randn(seq_length, bsz, d_model) 15050*da0073e9SAndroid Build Coastguard Worker tgt = torch.randn(tgt_length, 
bsz, d_model) 15051*da0073e9SAndroid Build Coastguard Worker transformer = nn.Transformer(d_model, nhead, num_encoder_layers, 15052*da0073e9SAndroid Build Coastguard Worker num_decoder_layers, dim_feedforward, dropout=0.0) 15053*da0073e9SAndroid Build Coastguard Worker model = MyModule(transformer, tgt, src) 15054*da0073e9SAndroid Build Coastguard Worker 15055*da0073e9SAndroid Build Coastguard Worker src = torch.randn(seq_length, bsz, d_model) 15056*da0073e9SAndroid Build Coastguard Worker tgt = torch.randn(tgt_length, bsz, d_model) 15057*da0073e9SAndroid Build Coastguard Worker jit_out = model(tgt, src) 15058*da0073e9SAndroid Build Coastguard Worker py_out = transformer(tgt, src) 15059*da0073e9SAndroid Build Coastguard Worker 15060*da0073e9SAndroid Build Coastguard Worker # print(jit_out/py_out-1) 15061*da0073e9SAndroid Build Coastguard Worker # print(torch.allclose(jit_out, py_out, atol=5e-4, rtol=1e-4)) 15062*da0073e9SAndroid Build Coastguard Worker self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4) 15063*da0073e9SAndroid Build Coastguard Worker 15064*da0073e9SAndroid Build Coastguard Worker def test_list_python_op(self): 15065*da0073e9SAndroid Build Coastguard Worker def python_list_op(lst): 15066*da0073e9SAndroid Build Coastguard Worker # type: (List[Tensor]) -> Tensor 15067*da0073e9SAndroid Build Coastguard Worker return lst[0] 15068*da0073e9SAndroid Build Coastguard Worker 15069*da0073e9SAndroid Build Coastguard Worker def fn(lst): 15070*da0073e9SAndroid Build Coastguard Worker # type: (List[Tensor]) -> Tensor 15071*da0073e9SAndroid Build Coastguard Worker return python_list_op(lst) 15072*da0073e9SAndroid Build Coastguard Worker 15073*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],)) 15074*da0073e9SAndroid Build Coastguard Worker 15075*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "no CUDA") 15076*da0073e9SAndroid Build Coastguard Worker def test_weak_cuda(self): 
15077*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15078*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15079*da0073e9SAndroid Build Coastguard Worker super().__init__() 15080*da0073e9SAndroid Build Coastguard Worker self.lstm = torch.nn.LSTM(5, 5) 15081*da0073e9SAndroid Build Coastguard Worker self.lstm.cuda() 15082*da0073e9SAndroid Build Coastguard Worker 15083*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15084*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15085*da0073e9SAndroid Build Coastguard Worker return self.lstm(x) 15086*da0073e9SAndroid Build Coastguard Worker 15087*da0073e9SAndroid Build Coastguard Worker m = M() 15088*da0073e9SAndroid Build Coastguard Worker m.cuda() 15089*da0073e9SAndroid Build Coastguard Worker out = m(torch.ones(5, 5, 5).cuda()) 15090*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out[0].is_cuda) 15091*da0073e9SAndroid Build Coastguard Worker 15092*da0073e9SAndroid Build Coastguard Worker def test_ignore_decorator(self): 15093*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as warns: 15094*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15095*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15096*da0073e9SAndroid Build Coastguard Worker super().__init__() 15097*da0073e9SAndroid Build Coastguard Worker tensor = torch.zeros(1, requires_grad=False) 15098*da0073e9SAndroid Build Coastguard Worker self.some_state = nn.Buffer(torch.nn.Parameter(tensor)) 15099*da0073e9SAndroid Build Coastguard Worker 15100*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15101*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15102*da0073e9SAndroid Build Coastguard Worker self.ignored_code(x) 15103*da0073e9SAndroid Build Coastguard Worker return x 15104*da0073e9SAndroid Build Coastguard Worker 15105*da0073e9SAndroid Build Coastguard Worker 
@torch.jit.ignore(drop_on_export=True) 15106*da0073e9SAndroid Build Coastguard Worker def ignored_code(self, x): 15107*da0073e9SAndroid Build Coastguard Worker self.some_state = torch.tensor((100,)) 15108*da0073e9SAndroid Build Coastguard Worker 15109*da0073e9SAndroid Build Coastguard Worker FileCheck().check("TorchScript will now drop the function").run(str(warns[0])) 15110*da0073e9SAndroid Build Coastguard Worker 15111*da0073e9SAndroid Build Coastguard Worker # Assert ignored code is run 15112*da0073e9SAndroid Build Coastguard Worker m = M() 15113*da0073e9SAndroid Build Coastguard Worker 15114*da0073e9SAndroid Build Coastguard Worker m2 = self.getExportImportCopy(m) 15115*da0073e9SAndroid Build Coastguard Worker pp = str(m2.forward.code) 15116*da0073e9SAndroid Build Coastguard Worker self.assertNotIn('ignored_code', pp) 15117*da0073e9SAndroid Build Coastguard Worker 15118*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"): 15119*da0073e9SAndroid Build Coastguard Worker m2.forward(torch.ones(1)) 15120*da0073e9SAndroid Build Coastguard Worker 15121*da0073e9SAndroid Build Coastguard Worker def test_ignored_as_value(self): 15122*da0073e9SAndroid Build Coastguard Worker class Model(nn.Module): 15123*da0073e9SAndroid Build Coastguard Worker @torch.jit.unused 15124*da0073e9SAndroid Build Coastguard Worker def tuple_ignored(self, x): 15125*da0073e9SAndroid Build Coastguard Worker # type: (Tensor) -> Tuple[Tensor, Tensor] 15126*da0073e9SAndroid Build Coastguard Worker return x, x 15127*da0073e9SAndroid Build Coastguard Worker 15128*da0073e9SAndroid Build Coastguard Worker @torch.jit.unused 15129*da0073e9SAndroid Build Coastguard Worker def single_val_ignored(self, x, y): 15130*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tensor) -> Tensor 15131*da0073e9SAndroid Build Coastguard Worker return x 15132*da0073e9SAndroid Build Coastguard Worker 15133*da0073e9SAndroid Build Coastguard 
Worker def forward(self, x, use_ignore_path): 15134*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, bool) -> Tuple[Tensor, Tensor] 15135*da0073e9SAndroid Build Coastguard Worker if 1 == 2: 15136*da0073e9SAndroid Build Coastguard Worker return self.tuple_ignored(x) 15137*da0073e9SAndroid Build Coastguard Worker if use_ignore_path: 15138*da0073e9SAndroid Build Coastguard Worker return self.single_val_ignored(x, x), self.single_val_ignored(x, x) 15139*da0073e9SAndroid Build Coastguard Worker return x, x 15140*da0073e9SAndroid Build Coastguard Worker 15141*da0073e9SAndroid Build Coastguard Worker original = Model() 15142*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(original) 15143*da0073e9SAndroid Build Coastguard Worker self.assertEqual(scripted(torch.tensor(.5), False), (torch.tensor(.5), torch.tensor(.5))) 15144*da0073e9SAndroid Build Coastguard Worker 15145*da0073e9SAndroid Build Coastguard Worker buffer = io.BytesIO() 15146*da0073e9SAndroid Build Coastguard Worker torch.jit.save(scripted, buffer) 15147*da0073e9SAndroid Build Coastguard Worker buffer.seek(0) 15148*da0073e9SAndroid Build Coastguard Worker loaded = torch.jit.load(buffer) 15149*da0073e9SAndroid Build Coastguard Worker 15150*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"): 15151*da0073e9SAndroid Build Coastguard Worker loaded(torch.tensor(.5), True) 15152*da0073e9SAndroid Build Coastguard Worker 15153*da0073e9SAndroid Build Coastguard Worker def test_module_error(self): 15154*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 15155*da0073e9SAndroid Build Coastguard Worker def forward(self, foo): 15156*da0073e9SAndroid Build Coastguard Worker return foo 15157*da0073e9SAndroid Build Coastguard Worker 15158*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "cannot be compiled since it inherits from nn.Module"): 15159*da0073e9SAndroid 
Build Coastguard Worker torch.jit.script(MyModule) 15160*da0073e9SAndroid Build Coastguard Worker 15161*da0073e9SAndroid Build Coastguard Worker def test_view_write(self): 15162*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 15163*da0073e9SAndroid Build Coastguard Worker l = [] 15164*da0073e9SAndroid Build Coastguard Worker l.append(x) 15165*da0073e9SAndroid Build Coastguard Worker x_view = l[0] 15166*da0073e9SAndroid Build Coastguard Worker a = x + x 15167*da0073e9SAndroid Build Coastguard Worker x_view.add_(y) 15168*da0073e9SAndroid Build Coastguard Worker b = x + x 15169*da0073e9SAndroid Build Coastguard Worker return a == b 15170*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3))) 15171*da0073e9SAndroid Build Coastguard Worker 15172*da0073e9SAndroid Build Coastguard Worker def test_module_attrs(self): 15173*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15174*da0073e9SAndroid Build Coastguard Worker def __init__(self, table): 15175*da0073e9SAndroid Build Coastguard Worker super().__init__() 15176*da0073e9SAndroid Build Coastguard Worker self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor]) 15177*da0073e9SAndroid Build Coastguard Worker self.x = torch.nn.Parameter(torch.tensor([100.0])) 15178*da0073e9SAndroid Build Coastguard Worker 15179*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15180*da0073e9SAndroid Build Coastguard Worker def forward(self, key): 15181*da0073e9SAndroid Build Coastguard Worker # type: (str) -> Tensor 15182*da0073e9SAndroid Build Coastguard Worker return self.table[key] + self.x 15183*da0073e9SAndroid Build Coastguard Worker 15184*da0073e9SAndroid Build Coastguard Worker with torch._jit_internal._disable_emit_hooks(): 15185*da0073e9SAndroid Build Coastguard Worker # TODO: re-enable module hook when Python printing of attributes is 15186*da0073e9SAndroid Build Coastguard Worker # supported 15187*da0073e9SAndroid Build Coastguard 
Worker m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"}) 15188*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m("c"), torch.tensor([103.])) 15189*da0073e9SAndroid Build Coastguard Worker 15190*da0073e9SAndroid Build Coastguard Worker def test_module_none_attrs(self): 15191*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.jit.ScriptModule): 15192*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15193*da0073e9SAndroid Build Coastguard Worker super().__init__() 15194*da0073e9SAndroid Build Coastguard Worker self.optional_value = None 15195*da0073e9SAndroid Build Coastguard Worker 15196*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15197*da0073e9SAndroid Build Coastguard Worker def forward(self): 15198*da0073e9SAndroid Build Coastguard Worker return self.optional_value 15199*da0073e9SAndroid Build Coastguard Worker 15200*da0073e9SAndroid Build Coastguard Worker graph = MyMod().forward.graph 15201*da0073e9SAndroid Build Coastguard Worker FileCheck().check("prim::GetAttr").run(graph) 15202*da0073e9SAndroid Build Coastguard Worker self.run_pass('peephole', graph) 15203*da0073e9SAndroid Build Coastguard Worker FileCheck().check_not("prim::GetAttr").run(graph) 15204*da0073e9SAndroid Build Coastguard Worker 15205*da0073e9SAndroid Build Coastguard Worker def test_tensor_import_export(self): 15206*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15207*da0073e9SAndroid Build Coastguard Worker def foo(x): 15208*da0073e9SAndroid Build Coastguard Worker a = torch.tensor(1) 15209*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([1, 2]) 15210*da0073e9SAndroid Build Coastguard Worker c = [a, b] 15211*da0073e9SAndroid Build Coastguard Worker return c 15212*da0073e9SAndroid Build Coastguard Worker 15213*da0073e9SAndroid Build Coastguard Worker self.run_pass('constant_propagation', foo.graph) 15214*da0073e9SAndroid Build Coastguard Worker m = self.createFunctionFromGraph(foo.graph) 
15215*da0073e9SAndroid Build Coastguard Worker self.getExportImportCopy(m) 15216*da0073e9SAndroid Build Coastguard Worker 15217*da0073e9SAndroid Build Coastguard Worker def get_pickle_values(self): 15218*da0073e9SAndroid Build Coastguard Worker return (('dict', {"I": "am", "a test": "test"}, Dict[str, str]), 15219*da0073e9SAndroid Build Coastguard Worker ('float', 2.3, float), 15220*da0073e9SAndroid Build Coastguard Worker ('int', 99, int), 15221*da0073e9SAndroid Build Coastguard Worker ('bool', False, bool), 15222*da0073e9SAndroid Build Coastguard Worker ('tuple', (1, 2, 3, 4), Tuple[int, int, int, int]), 15223*da0073e9SAndroid Build Coastguard Worker ('list', [(1, 2), (3, 4)], List[Tuple[int, int]]), 15224*da0073e9SAndroid Build Coastguard Worker ('tensor', torch.randn(2, 2), torch.Tensor), 15225*da0073e9SAndroid Build Coastguard Worker ('int_list', [1, 2, 3, 4], List[int]), 15226*da0073e9SAndroid Build Coastguard Worker ('tensor_list', [torch.ones(2, 2) + i for i in range(4)], List[torch.Tensor]), 15227*da0073e9SAndroid Build Coastguard Worker ('bool_list', [True, True, False, True], List[bool]), 15228*da0073e9SAndroid Build Coastguard Worker ('float_list', [1., 2., 3., 4.], List[float]), 15229*da0073e9SAndroid Build Coastguard Worker ('str_list', ['hello', 'bye'], List[str]), 15230*da0073e9SAndroid Build Coastguard Worker ('none', None, Optional[int]), 15231*da0073e9SAndroid Build Coastguard Worker ('a_device', torch.device('cpu'), torch.device), 15232*da0073e9SAndroid Build Coastguard Worker ('another_device', torch.device('cuda:1'), torch.device)) 15233*da0073e9SAndroid Build Coastguard Worker 15234*da0073e9SAndroid Build Coastguard Worker def test_attribute_serialization(self): 15235*da0073e9SAndroid Build Coastguard Worker tester = self 15236*da0073e9SAndroid Build Coastguard Worker 15237*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15238*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15239*da0073e9SAndroid 
Build Coastguard Worker super().__init__() 15240*da0073e9SAndroid Build Coastguard Worker for name, value, the_type in tester.get_pickle_values(): 15241*da0073e9SAndroid Build Coastguard Worker setattr(self, name, torch.jit.Attribute(value, the_type)) 15242*da0073e9SAndroid Build Coastguard Worker 15243*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15244*da0073e9SAndroid Build Coastguard Worker def forward(self): 15245*da0073e9SAndroid Build Coastguard Worker return (self.dict, self.float, self.int, self.bool, self.tuple, 15246*da0073e9SAndroid Build Coastguard Worker self.list, self.int_list, self.tensor_list, self.bool_list, 15247*da0073e9SAndroid Build Coastguard Worker self.float_list, self.str_list, self.none) 15248*da0073e9SAndroid Build Coastguard Worker 15249*da0073e9SAndroid Build Coastguard Worker m = M() 15250*da0073e9SAndroid Build Coastguard Worker imported_m = self.getExportImportCopy(m) 15251*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(), imported_m()) 15252*da0073e9SAndroid Build Coastguard Worker 15253*da0073e9SAndroid Build Coastguard Worker def test_string_len(self): 15254*da0073e9SAndroid Build Coastguard Worker def fn(x): 15255*da0073e9SAndroid Build Coastguard Worker # type: (str) -> int 15256*da0073e9SAndroid Build Coastguard Worker return len(x) 15257*da0073e9SAndroid Build Coastguard Worker 15258*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("",)) 15259*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("h",)) 15260*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("hello",)) 15261*da0073e9SAndroid Build Coastguard Worker 15262*da0073e9SAndroid Build Coastguard Worker def test_multiline_optional_future_refinement(self): 15263*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15264*da0073e9SAndroid Build Coastguard Worker def fun() -> int: 15265*da0073e9SAndroid Build Coastguard Worker future: Optional[ 15266*da0073e9SAndroid Build Coastguard Worker 
torch.jit.Future[Tuple[torch.Tensor]] 15267*da0073e9SAndroid Build Coastguard Worker ] = None 15268*da0073e9SAndroid Build Coastguard Worker 15269*da0073e9SAndroid Build Coastguard Worker return 1 15270*da0073e9SAndroid Build Coastguard Worker self.assertEqual(fun(), 1) 15271*da0073e9SAndroid Build Coastguard Worker 15272*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle") 15273*da0073e9SAndroid Build Coastguard Worker def test_attribute_unpickling(self): 15274*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(2, 2) 15275*da0073e9SAndroid Build Coastguard Worker tester = self 15276*da0073e9SAndroid Build Coastguard Worker 15277*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15278*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15279*da0073e9SAndroid Build Coastguard Worker super().__init__() 15280*da0073e9SAndroid Build Coastguard Worker for name, value, the_type in tester.get_pickle_values(): 15281*da0073e9SAndroid Build Coastguard Worker setattr(self, "_" + name, torch.jit.Attribute(value, the_type)) 15282*da0073e9SAndroid Build Coastguard Worker 15283*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15284*da0073e9SAndroid Build Coastguard Worker def forward(self): 15285*da0073e9SAndroid Build Coastguard Worker return (self._dict, self._float, self._int, self._bool, self._tuple, 15286*da0073e9SAndroid Build Coastguard Worker self._list, self._int_list, self._tensor_list, self._bool_list, 15287*da0073e9SAndroid Build Coastguard Worker self._float_list, self._str_list, self._none) 15288*da0073e9SAndroid Build Coastguard Worker 15289*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 15290*da0073e9SAndroid Build Coastguard Worker M().save(fname) 15291*da0073e9SAndroid Build Coastguard Worker loaded = torch.jit.load(fname) 15292*da0073e9SAndroid Build Coastguard Worker 
15293*da0073e9SAndroid Build Coastguard Worker def is_tensor_value(item): 15294*da0073e9SAndroid Build Coastguard Worker if isinstance(item, torch.Tensor): 15295*da0073e9SAndroid Build Coastguard Worker return True 15296*da0073e9SAndroid Build Coastguard Worker if isinstance(item, list): 15297*da0073e9SAndroid Build Coastguard Worker return is_tensor_value(item[0]) 15298*da0073e9SAndroid Build Coastguard Worker return False 15299*da0073e9SAndroid Build Coastguard Worker for name, value, the_type in self.get_pickle_values(): 15300*da0073e9SAndroid Build Coastguard Worker if is_tensor_value(value): 15301*da0073e9SAndroid Build Coastguard Worker continue 15302*da0073e9SAndroid Build Coastguard Worker self.assertEqual(value, getattr(loaded, "_" + name)) 15303*da0073e9SAndroid Build Coastguard Worker 15304*da0073e9SAndroid Build Coastguard Worker 15305*da0073e9SAndroid Build Coastguard Worker def test_submodule_attribute_serialization(self): 15306*da0073e9SAndroid Build Coastguard Worker class S(torch.jit.ScriptModule): 15307*da0073e9SAndroid Build Coastguard Worker def __init__(self, list_data): 15308*da0073e9SAndroid Build Coastguard Worker super().__init__() 15309*da0073e9SAndroid Build Coastguard Worker self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str]) 15310*da0073e9SAndroid Build Coastguard Worker self.list = torch.jit.Attribute(list_data, List[Tuple[int, int]]) 15311*da0073e9SAndroid Build Coastguard Worker 15312*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15313*da0073e9SAndroid Build Coastguard Worker def forward(self): 15314*da0073e9SAndroid Build Coastguard Worker return (self.table, self.list) 15315*da0073e9SAndroid Build Coastguard Worker 15316*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15317*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15318*da0073e9SAndroid Build Coastguard Worker super().__init__() 15319*da0073e9SAndroid Build Coastguard Worker 
self.table = torch.jit.Attribute({"this": "is", "a different": "dict"}, Dict[str, str]) 15320*da0073e9SAndroid Build Coastguard Worker self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor) 15321*da0073e9SAndroid Build Coastguard Worker self.s1 = S([(1, 2)]) 15322*da0073e9SAndroid Build Coastguard Worker self.s2 = S([(4, 5)]) 15323*da0073e9SAndroid Build Coastguard Worker 15324*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15325*da0073e9SAndroid Build Coastguard Worker def forward(self): 15326*da0073e9SAndroid Build Coastguard Worker return (self.table, self.tensor, self.s1.table, self.s2.list, self.s1.list) 15327*da0073e9SAndroid Build Coastguard Worker 15328*da0073e9SAndroid Build Coastguard Worker m = M() 15329*da0073e9SAndroid Build Coastguard Worker imported_m = self.getExportImportCopy(m) 15330*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(), imported_m()) 15331*da0073e9SAndroid Build Coastguard Worker 15332*da0073e9SAndroid Build Coastguard Worker def test_serialization_big_ints(self): 15333*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15334*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15335*da0073e9SAndroid Build Coastguard Worker super().__init__() 15336*da0073e9SAndroid Build Coastguard Worker self.int32_max = torch.jit.Attribute(2**31 - 1, int) 15337*da0073e9SAndroid Build Coastguard Worker self.int32_min = torch.jit.Attribute(-2**31, int) 15338*da0073e9SAndroid Build Coastguard Worker self.uint32_max = torch.jit.Attribute(2**32, int) 15339*da0073e9SAndroid Build Coastguard Worker 15340*da0073e9SAndroid Build Coastguard Worker self.int64_max = torch.jit.Attribute(2**63 - 1, int) 15341*da0073e9SAndroid Build Coastguard Worker self.int64_min = torch.jit.Attribute(-2**63, int) 15342*da0073e9SAndroid Build Coastguard Worker 15343*da0073e9SAndroid Build Coastguard Worker self.tensor = torch.nn.Parameter(torch.ones(2, 2)) 15344*da0073e9SAndroid Build Coastguard 
Worker 15345*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15346*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15347*da0073e9SAndroid Build Coastguard Worker # type: (int) -> (int) 15348*da0073e9SAndroid Build Coastguard Worker return x + (self.int32_max + self.int32_min) + (self.int64_max + self.int64_min) 15349*da0073e9SAndroid Build Coastguard Worker 15350*da0073e9SAndroid Build Coastguard Worker m = M() 15351*da0073e9SAndroid Build Coastguard Worker imported = self.getExportImportCopy(m) 15352*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(10), imported(10)) 15353*da0073e9SAndroid Build Coastguard Worker 15354*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.int32_max, imported.int32_max) 15355*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.int32_min, imported.int32_min) 15356*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.uint32_max, imported.uint32_max) 15357*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.int64_max, imported.int64_max) 15358*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m.int64_min, imported.int64_min) 15359*da0073e9SAndroid Build Coastguard Worker 15360*da0073e9SAndroid Build Coastguard Worker def test_script_scope(self): 15361*da0073e9SAndroid Build Coastguard Worker scripted = torch.jit.script(torch.nn.functional.triplet_margin_loss) 15362*da0073e9SAndroid Build Coastguard Worker 15363*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "NYI: TemporaryFileName on Windows") 15364*da0073e9SAndroid Build Coastguard Worker def test_serialization_sharing(self): 15365*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15366*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15367*da0073e9SAndroid Build Coastguard Worker super().__init__() 15368*da0073e9SAndroid Build Coastguard Worker self.list = torch.jit.Attribute([], List[str]) 15369*da0073e9SAndroid Build Coastguard Worker 
15370*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15371*da0073e9SAndroid Build Coastguard Worker def forward(self, key): 15372*da0073e9SAndroid Build Coastguard Worker # type: (str) -> List[str] 15373*da0073e9SAndroid Build Coastguard Worker self.list.append(key) 15374*da0073e9SAndroid Build Coastguard Worker self.list.append(key) 15375*da0073e9SAndroid Build Coastguard Worker self.list.append(key) 15376*da0073e9SAndroid Build Coastguard Worker return self.list 15377*da0073e9SAndroid Build Coastguard Worker 15378*da0073e9SAndroid Build Coastguard Worker # the text of the string should only appear once in the pickling 15379*da0073e9SAndroid Build Coastguard Worker m = M() 15380*da0073e9SAndroid Build Coastguard Worker s1 = "a long string" 15381*da0073e9SAndroid Build Coastguard Worker s2 = "a different, even longer string" 15382*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(s1), [s1] * 3) 15383*da0073e9SAndroid Build Coastguard Worker self.assertEqual(m(s2), [s1] * 3 + [s2] * 3) 15384*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 15385*da0073e9SAndroid Build Coastguard Worker m.save(fname) 15386*da0073e9SAndroid Build Coastguard Worker archive_name = os.path.basename(os.path.normpath(fname)) 15387*da0073e9SAndroid Build Coastguard Worker archive = zipfile.ZipFile(fname, 'r') 15388*da0073e9SAndroid Build Coastguard Worker pickled_data = archive.read(os.path.join(archive_name, 'data.pkl')) 15389*da0073e9SAndroid Build Coastguard Worker 15390*da0073e9SAndroid Build Coastguard Worker out = io.StringIO() 15391*da0073e9SAndroid Build Coastguard Worker pickletools.dis(pickled_data, out=out) 15392*da0073e9SAndroid Build Coastguard Worker disassembled = out.getvalue() 15393*da0073e9SAndroid Build Coastguard Worker 15394*da0073e9SAndroid Build Coastguard Worker FileCheck().check_count(s1, 1, exactly=True) \ 15395*da0073e9SAndroid Build Coastguard Worker .check_count("BINGET", 2, exactly=True) \ 
15396*da0073e9SAndroid Build Coastguard Worker .check_count(s2, 1, exactly=True) \ 15397*da0073e9SAndroid Build Coastguard Worker .check_count("BINGET", 2, exactly=True).run(out.getvalue()) 15398*da0073e9SAndroid Build Coastguard Worker 15399*da0073e9SAndroid Build Coastguard Worker def test_sys_stdout_override(self): 15400*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15401*da0073e9SAndroid Build Coastguard Worker def foo(): 15402*da0073e9SAndroid Build Coastguard Worker print('foo') 15403*da0073e9SAndroid Build Coastguard Worker 15404*da0073e9SAndroid Build Coastguard Worker class Redirect: 15405*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15406*da0073e9SAndroid Build Coastguard Worker self.s = '' 15407*da0073e9SAndroid Build Coastguard Worker 15408*da0073e9SAndroid Build Coastguard Worker def write(self, s): 15409*da0073e9SAndroid Build Coastguard Worker self.s += s 15410*da0073e9SAndroid Build Coastguard Worker 15411*da0073e9SAndroid Build Coastguard Worker old_stdout = sys.stdout 15412*da0073e9SAndroid Build Coastguard Worker redirect = Redirect() 15413*da0073e9SAndroid Build Coastguard Worker try: 15414*da0073e9SAndroid Build Coastguard Worker sys.stdout = redirect 15415*da0073e9SAndroid Build Coastguard Worker foo() 15416*da0073e9SAndroid Build Coastguard Worker finally: 15417*da0073e9SAndroid Build Coastguard Worker sys.stdout = old_stdout 15418*da0073e9SAndroid Build Coastguard Worker 15419*da0073e9SAndroid Build Coastguard Worker FileCheck().check('foo').run(redirect.s) 15420*da0073e9SAndroid Build Coastguard Worker 15421*da0073e9SAndroid Build Coastguard Worker def test_dtype_attr(self): 15422*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 15423*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15424*da0073e9SAndroid Build Coastguard Worker super().__init__() 15425*da0073e9SAndroid Build Coastguard Worker self.dtype = torch.zeros([]).dtype 15426*da0073e9SAndroid Build 
Coastguard Worker 15427*da0073e9SAndroid Build Coastguard Worker def forward(self): 15428*da0073e9SAndroid Build Coastguard Worker return torch.zeros(3, 4, dtype=self.dtype) 15429*da0073e9SAndroid Build Coastguard Worker 15430*da0073e9SAndroid Build Coastguard Worker f = Foo() 15431*da0073e9SAndroid Build Coastguard Worker torch.jit.script(f) 15432*da0073e9SAndroid Build Coastguard Worker 15433*da0073e9SAndroid Build Coastguard Worker 15434*da0073e9SAndroid Build Coastguard Worker def test_named_buffers_are_iterable(self): 15435*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 15436*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15437*da0073e9SAndroid Build Coastguard Worker super().__init__() 15438*da0073e9SAndroid Build Coastguard Worker self.mod = (torch.nn.ReLU()) 15439*da0073e9SAndroid Build Coastguard Worker self.mod2 = (torch.nn.ReLU()) 15440*da0073e9SAndroid Build Coastguard Worker self.mod3 = torch.nn.Sequential(torch.nn.Sequential(torch.nn.ReLU())) 15441*da0073e9SAndroid Build Coastguard Worker self.x = nn.Buffer(torch.zeros(3)) 15442*da0073e9SAndroid Build Coastguard Worker self.y = nn.Buffer(torch.zeros(3)) 15443*da0073e9SAndroid Build Coastguard Worker self.z = torch.zeros(3) 15444*da0073e9SAndroid Build Coastguard Worker 15445*da0073e9SAndroid Build Coastguard Worker def bleh(self): 15446*da0073e9SAndroid Build Coastguard Worker return self.z + 4 15447*da0073e9SAndroid Build Coastguard Worker 15448*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 15449*da0073e9SAndroid Build Coastguard Worker def method(self): 15450*da0073e9SAndroid Build Coastguard Worker names = [""] 15451*da0073e9SAndroid Build Coastguard Worker vals = [] 15452*da0073e9SAndroid Build Coastguard Worker for name, buffer in self.named_buffers(): 15453*da0073e9SAndroid Build Coastguard Worker names.append(name) 15454*da0073e9SAndroid Build Coastguard Worker vals.append(buffer + 2) 15455*da0073e9SAndroid Build Coastguard Worker 
15456*da0073e9SAndroid Build Coastguard Worker return names, vals 15457*da0073e9SAndroid Build Coastguard Worker 15458*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15459*da0073e9SAndroid Build Coastguard Worker return x 15460*da0073e9SAndroid Build Coastguard Worker 15461*da0073e9SAndroid Build Coastguard Worker model = MyMod() 15462*da0073e9SAndroid Build Coastguard Worker x = torch.jit.script(model) 15463*da0073e9SAndroid Build Coastguard Worker z = self.getExportImportCopy(x) 15464*da0073e9SAndroid Build Coastguard Worker 15465*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z.method(), x.method()) 15466*da0073e9SAndroid Build Coastguard Worker self.assertEqual(z.method(), model.method()) 15467*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.method(), model.method()) 15468*da0073e9SAndroid Build Coastguard Worker names = x.method() 15469*da0073e9SAndroid Build Coastguard Worker for name in names: 15470*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual('z', name) 15471*da0073e9SAndroid Build Coastguard Worker 15472*da0073e9SAndroid Build Coastguard Worker 15473*da0073e9SAndroid Build Coastguard Worker def test_static_if_prop(self): 15474*da0073e9SAndroid Build Coastguard Worker class MaybeHasAttr(torch.nn.Module): 15475*da0073e9SAndroid Build Coastguard Worker def __init__(self, add_attr): 15476*da0073e9SAndroid Build Coastguard Worker super().__init__() 15477*da0073e9SAndroid Build Coastguard Worker if add_attr: 15478*da0073e9SAndroid Build Coastguard Worker self.maybe_attr = 1 15479*da0073e9SAndroid Build Coastguard Worker 15480*da0073e9SAndroid Build Coastguard Worker def forward(self): 15481*da0073e9SAndroid Build Coastguard Worker if hasattr(self, "maybe_attr") and True: 15482*da0073e9SAndroid Build Coastguard Worker return self.maybe_attr 15483*da0073e9SAndroid Build Coastguard Worker else: 15484*da0073e9SAndroid Build Coastguard Worker return 0 15485*da0073e9SAndroid Build Coastguard Worker 
15486*da0073e9SAndroid Build Coastguard Worker class MaybeHasAttr2(torch.nn.Module): 15487*da0073e9SAndroid Build Coastguard Worker def __init__(self, add_attr): 15488*da0073e9SAndroid Build Coastguard Worker super().__init__() 15489*da0073e9SAndroid Build Coastguard Worker if add_attr: 15490*da0073e9SAndroid Build Coastguard Worker self.maybe_attr = 1 15491*da0073e9SAndroid Build Coastguard Worker 15492*da0073e9SAndroid Build Coastguard Worker def forward(self): 15493*da0073e9SAndroid Build Coastguard Worker if not hasattr(self, "maybe_attr") or False: 15494*da0073e9SAndroid Build Coastguard Worker return 0 15495*da0073e9SAndroid Build Coastguard Worker else: 15496*da0073e9SAndroid Build Coastguard Worker return self.maybe_attr 15497*da0073e9SAndroid Build Coastguard Worker 15498*da0073e9SAndroid Build Coastguard Worker torch.jit.script(MaybeHasAttr(True)) 15499*da0073e9SAndroid Build Coastguard Worker torch.jit.script(MaybeHasAttr(False)) 15500*da0073e9SAndroid Build Coastguard Worker torch.jit.script(MaybeHasAttr2(True)) 15501*da0073e9SAndroid Build Coastguard Worker torch.jit.script(MaybeHasAttr2(False)) 15502*da0073e9SAndroid Build Coastguard Worker 15503*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 15504*da0073e9SAndroid Build Coastguard Worker def forward(self): 15505*da0073e9SAndroid Build Coastguard Worker if hasattr(self, "foo"): 15506*da0073e9SAndroid Build Coastguard Worker return 1 15507*da0073e9SAndroid Build Coastguard Worker else: 15508*da0073e9SAndroid Build Coastguard Worker return 0 15509*da0073e9SAndroid Build Coastguard Worker 15510*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 15511*da0073e9SAndroid Build Coastguard Worker def fee(self): 15512*da0073e9SAndroid Build Coastguard Worker return 1 15513*da0073e9SAndroid Build Coastguard Worker 15514*da0073e9SAndroid Build Coastguard Worker self.checkModule(MyMod(), ()) 15515*da0073e9SAndroid Build Coastguard Worker 15516*da0073e9SAndroid Build Coastguard 
Worker class HasAttrMod(torch.nn.Module): 15517*da0073e9SAndroid Build Coastguard Worker __constants__ = ["fee"] 15518*da0073e9SAndroid Build Coastguard Worker 15519*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15520*da0073e9SAndroid Build Coastguard Worker super().__init__() 15521*da0073e9SAndroid Build Coastguard Worker self.fee = 3 15522*da0073e9SAndroid Build Coastguard Worker 15523*da0073e9SAndroid Build Coastguard Worker def forward(self): 15524*da0073e9SAndroid Build Coastguard Worker a = hasattr(self, "fee") 15525*da0073e9SAndroid Build Coastguard Worker b = hasattr(self, "foo") 15526*da0073e9SAndroid Build Coastguard Worker c = hasattr(self, "hi") 15527*da0073e9SAndroid Build Coastguard Worker d = hasattr(self, "nonexistant") 15528*da0073e9SAndroid Build Coastguard Worker return (a, b, c, d) 15529*da0073e9SAndroid Build Coastguard Worker 15530*da0073e9SAndroid Build Coastguard Worker def foo(self): 15531*da0073e9SAndroid Build Coastguard Worker return 1 15532*da0073e9SAndroid Build Coastguard Worker 15533*da0073e9SAndroid Build Coastguard Worker @torch.jit._overload_method 15534*da0073e9SAndroid Build Coastguard Worker def hi(self, x: Tensor): ... 
# noqa: E704 15535*da0073e9SAndroid Build Coastguard Worker 15536*da0073e9SAndroid Build Coastguard Worker def hi(self, x): # noqa: F811 15537*da0073e9SAndroid Build Coastguard Worker return 2 15538*da0073e9SAndroid Build Coastguard Worker 15539*da0073e9SAndroid Build Coastguard Worker self.checkModule(HasAttrMod(), ()) 15540*da0073e9SAndroid Build Coastguard Worker 15541*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15542*da0073e9SAndroid Build Coastguard Worker class FooTest: 15543*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15544*da0073e9SAndroid Build Coastguard Worker self.x = 1 15545*da0073e9SAndroid Build Coastguard Worker 15546*da0073e9SAndroid Build Coastguard Worker def foo(self, y): 15547*da0073e9SAndroid Build Coastguard Worker return self.x + y 15548*da0073e9SAndroid Build Coastguard Worker 15549*da0073e9SAndroid Build Coastguard Worker def foo(): 15550*da0073e9SAndroid Build Coastguard Worker a = FooTest() 15551*da0073e9SAndroid Build Coastguard Worker val1 = hasattr(a, "foo"), hasattr(a, "x"), hasattr(a, "bla") 15552*da0073e9SAndroid Build Coastguard Worker val2 = hasattr(FooTest, "foo"), hasattr(FooTest, "a") 15553*da0073e9SAndroid Build Coastguard Worker return val1, val2 15554*da0073e9SAndroid Build Coastguard Worker 15555*da0073e9SAndroid Build Coastguard Worker self.assertEqual(foo(), torch.jit.script(foo)()) 15556*da0073e9SAndroid Build Coastguard Worker 15557*da0073e9SAndroid Build Coastguard Worker def _test_pickle_checkpoint(self, device): 15558*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 15559*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15560*da0073e9SAndroid Build Coastguard Worker __constants__ = ['fname'] 15561*da0073e9SAndroid Build Coastguard Worker 15562*da0073e9SAndroid Build Coastguard Worker def __init__(self, tensor): 15563*da0073e9SAndroid Build Coastguard Worker super().__init__() 15564*da0073e9SAndroid Build Coastguard Worker 
self.fname = fname 15565*da0073e9SAndroid Build Coastguard Worker self.tensor = torch.nn.Parameter(tensor) 15566*da0073e9SAndroid Build Coastguard Worker 15567*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15568*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15569*da0073e9SAndroid Build Coastguard Worker y = self.tensor + x 15570*da0073e9SAndroid Build Coastguard Worker torch.save(y, self.fname) 15571*da0073e9SAndroid Build Coastguard Worker return y 15572*da0073e9SAndroid Build Coastguard Worker 15573*da0073e9SAndroid Build Coastguard Worker param = torch.randn(2, 2).to(device) 15574*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2).to(device) 15575*da0073e9SAndroid Build Coastguard Worker m = M(param) 15576*da0073e9SAndroid Build Coastguard Worker m(input) 15577*da0073e9SAndroid Build Coastguard Worker with open(fname, "rb") as handle: 15578*da0073e9SAndroid Build Coastguard Worker loaded_tensor = torch.load(fname) 15579*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded_tensor, input + param) 15580*da0073e9SAndroid Build Coastguard Worker 15581*da0073e9SAndroid Build Coastguard Worker def _test_pickle_checkpoint_views(self, device): 15582*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 15583*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15584*da0073e9SAndroid Build Coastguard Worker __constants__ = ['fname'] 15585*da0073e9SAndroid Build Coastguard Worker 15586*da0073e9SAndroid Build Coastguard Worker def __init__(self, tensor): 15587*da0073e9SAndroid Build Coastguard Worker super().__init__() 15588*da0073e9SAndroid Build Coastguard Worker self.fname = fname 15589*da0073e9SAndroid Build Coastguard Worker self.tensor = torch.nn.Parameter(tensor) 15590*da0073e9SAndroid Build Coastguard Worker 15591*da0073e9SAndroid Build Coastguard Worker @torch.jit.script_method 15592*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 
15593*da0073e9SAndroid Build Coastguard Worker y = self.tensor + x 15594*da0073e9SAndroid Build Coastguard Worker y_view = y.view(4) 15595*da0073e9SAndroid Build Coastguard Worker torch.save((y, y_view, y), self.fname) 15596*da0073e9SAndroid Build Coastguard Worker return y 15597*da0073e9SAndroid Build Coastguard Worker 15598*da0073e9SAndroid Build Coastguard Worker param = torch.randn(2, 2).to(device) 15599*da0073e9SAndroid Build Coastguard Worker input = torch.randn(2, 2).to(device) 15600*da0073e9SAndroid Build Coastguard Worker m = M(param) 15601*da0073e9SAndroid Build Coastguard Worker m(input) 15602*da0073e9SAndroid Build Coastguard Worker with open(fname, "rb") as handle: 15603*da0073e9SAndroid Build Coastguard Worker loaded_y, loaded_y_view, loaded_y_2 = torch.load(fname) 15604*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded_y, input + param) 15605*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 15606*da0073e9SAndroid Build Coastguard Worker loaded_y_view[1] += 20 15607*da0073e9SAndroid Build Coastguard Worker # assert that loaded_y changed as well 15608*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded_y.view(4), loaded_y_view) 15609*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded_y_2.view(4), loaded_y_view) 15610*da0073e9SAndroid Build Coastguard Worker 15611*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not RUN_CUDA, "no CUDA") 15612*da0073e9SAndroid Build Coastguard Worker def test_pickle_checkpoint_cuda(self): 15613*da0073e9SAndroid Build Coastguard Worker self._test_pickle_checkpoint('cuda') 15614*da0073e9SAndroid Build Coastguard Worker self._test_pickle_checkpoint_views('cuda') 15615*da0073e9SAndroid Build Coastguard Worker 15616*da0073e9SAndroid Build Coastguard Worker def test_pickle_checkpoint(self): 15617*da0073e9SAndroid Build Coastguard Worker self._test_pickle_checkpoint('cpu') 15618*da0073e9SAndroid Build Coastguard Worker self._test_pickle_checkpoint_views('cpu') 
15619*da0073e9SAndroid Build Coastguard Worker 15620*da0073e9SAndroid Build Coastguard Worker def test_pickle_checkpoint_tup(self): 15621*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15622*da0073e9SAndroid Build Coastguard Worker def foo(fname): 15623*da0073e9SAndroid Build Coastguard Worker # type: (str) -> None 15624*da0073e9SAndroid Build Coastguard Worker torch.save((3, 4), fname) 15625*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as name: 15626*da0073e9SAndroid Build Coastguard Worker foo(name) 15627*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.load(name), (3, 4)) 15628*da0073e9SAndroid Build Coastguard Worker 15629*da0073e9SAndroid Build Coastguard Worker def test_string_list(self): 15630*da0073e9SAndroid Build Coastguard Worker def fn(string): 15631*da0073e9SAndroid Build Coastguard Worker # type: (str) -> List[str] 15632*da0073e9SAndroid Build Coastguard Worker return list(string) 15633*da0073e9SAndroid Build Coastguard Worker 15634*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ("abcdefgh",)) 15635*da0073e9SAndroid Build Coastguard Worker 15636*da0073e9SAndroid Build Coastguard Worker def test_unicode_comments(self): 15637*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15638*da0073e9SAndroid Build Coastguard Worker def test(self, a): 15639*da0073e9SAndroid Build Coastguard Worker # 15640*da0073e9SAndroid Build Coastguard Worker return torch.nn.functional.relu(a) 15641*da0073e9SAndroid Build Coastguard Worker 15642*da0073e9SAndroid Build Coastguard Worker def test_get_set_state_with_tensors(self): 15643*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 15644*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15645*da0073e9SAndroid Build Coastguard Worker super().__init__() 15646*da0073e9SAndroid Build Coastguard Worker self.tensor = torch.randn(2, 2) 15647*da0073e9SAndroid Build Coastguard Worker 15648*da0073e9SAndroid Build Coastguard Worker 
@torch.jit.export 15649*da0073e9SAndroid Build Coastguard Worker def __getstate__(self): 15650*da0073e9SAndroid Build Coastguard Worker return (self.tensor, self.training) 15651*da0073e9SAndroid Build Coastguard Worker 15652*da0073e9SAndroid Build Coastguard Worker @torch.jit.export 15653*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, state): 15654*da0073e9SAndroid Build Coastguard Worker self.tensor = state[0] 15655*da0073e9SAndroid Build Coastguard Worker self.training = state[1] 15656*da0073e9SAndroid Build Coastguard Worker 15657*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15658*da0073e9SAndroid Build Coastguard Worker return x + self.tensor 15659*da0073e9SAndroid Build Coastguard Worker 15660*da0073e9SAndroid Build Coastguard Worker with TemporaryFileName() as fname: 15661*da0073e9SAndroid Build Coastguard Worker m = torch.jit.script(M()) 15662*da0073e9SAndroid Build Coastguard Worker m.save(fname) 15663*da0073e9SAndroid Build Coastguard Worker loaded = torch.jit.load(fname) 15664*da0073e9SAndroid Build Coastguard Worker self.assertEqual(loaded.tensor, m.tensor) 15665*da0073e9SAndroid Build Coastguard Worker 15666*da0073e9SAndroid Build Coastguard Worker def test_in_for_and_comp_expr(self): 15667*da0073e9SAndroid Build Coastguard Worker def fn(d): 15668*da0073e9SAndroid Build Coastguard Worker # type: (Dict[str, int]) -> List[int] 15669*da0073e9SAndroid Build Coastguard Worker out = [1] 15670*da0073e9SAndroid Build Coastguard Worker for i in range(d["hi"] if "hi" in d else 6): 15671*da0073e9SAndroid Build Coastguard Worker out.append(i) # noqa: PERF402 15672*da0073e9SAndroid Build Coastguard Worker return out 15673*da0073e9SAndroid Build Coastguard Worker 15674*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ({'hi': 2, 'bye': 3},)) 15675*da0073e9SAndroid Build Coastguard Worker self.checkScript(fn, ({'bye': 3},)) 15676*da0073e9SAndroid Build Coastguard Worker 15677*da0073e9SAndroid Build Coastguard Worker def 
test_for_else(self): 15678*da0073e9SAndroid Build Coastguard Worker def fn(): 15679*da0073e9SAndroid Build Coastguard Worker c = 0 15680*da0073e9SAndroid Build Coastguard Worker for i in range(4): 15681*da0073e9SAndroid Build Coastguard Worker c += 10 15682*da0073e9SAndroid Build Coastguard Worker else: 15683*da0073e9SAndroid Build Coastguard Worker print("In else block of for...else") 15684*da0073e9SAndroid Build Coastguard Worker 15685*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "else branches of for loops aren't supported"): 15686*da0073e9SAndroid Build Coastguard Worker torch.jit.script(fn) 15687*da0073e9SAndroid Build Coastguard Worker 15688*da0073e9SAndroid Build Coastguard Worker def test_split(self): 15689*da0073e9SAndroid Build Coastguard Worker def split_two(tensor): 15690*da0073e9SAndroid Build Coastguard Worker a, b, c = torch.split(tensor, 2, dim=1) 15691*da0073e9SAndroid Build Coastguard Worker return a, b, c 15692*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 6) 15693*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, 6) 15694*da0073e9SAndroid Build Coastguard Worker self.checkScript(split_two, [(x + y)]) 15695*da0073e9SAndroid Build Coastguard Worker 15696*da0073e9SAndroid Build Coastguard Worker def test_conv_error(self): 15697*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15698*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 15699*da0073e9SAndroid Build Coastguard Worker return F.conv2d(x, y) 15700*da0073e9SAndroid Build Coastguard Worker 15701*da0073e9SAndroid Build Coastguard Worker try: 15702*da0073e9SAndroid Build Coastguard Worker fn(torch.ones(2, 2), torch.ones(4, 4)) 15703*da0073e9SAndroid Build Coastguard Worker except RuntimeError as e: 15704*da0073e9SAndroid Build Coastguard Worker self.assertFalse('frame' in str(e)) 15705*da0073e9SAndroid Build Coastguard Worker 15706*da0073e9SAndroid Build Coastguard Worker def 
test_python_op_name(self): 15707*da0073e9SAndroid Build Coastguard Worker import random 15708*da0073e9SAndroid Build Coastguard Worker 15709*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "randint"): 15710*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15711*da0073e9SAndroid Build Coastguard Worker def fn(): 15712*da0073e9SAndroid Build Coastguard Worker return random.randint() 15713*da0073e9SAndroid Build Coastguard Worker 15714*da0073e9SAndroid Build Coastguard Worker def test_dir(self): 15715*da0073e9SAndroid Build Coastguard Worker class M(torch.jit.ScriptModule): 15716*da0073e9SAndroid Build Coastguard Worker def forward(self, t): 15717*da0073e9SAndroid Build Coastguard Worker return t 15718*da0073e9SAndroid Build Coastguard Worker 15719*da0073e9SAndroid Build Coastguard Worker self.assertTrue('forward' in dir(M())) 15720*da0073e9SAndroid Build Coastguard Worker 15721*da0073e9SAndroid Build Coastguard Worker def test_kwarg_expansion_error(self): 15722*da0073e9SAndroid Build Coastguard Worker @torch.jit.ignore 15723*da0073e9SAndroid Build Coastguard Worker def something_else(h, i): 15724*da0073e9SAndroid Build Coastguard Worker pass 15725*da0073e9SAndroid Build Coastguard Worker 15726*da0073e9SAndroid Build Coastguard Worker def fn(x): 15727*da0073e9SAndroid Build Coastguard Worker something_else(**x) 15728*da0073e9SAndroid Build Coastguard Worker 15729*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "keyword-arg expansion is not supported"): 15730*da0073e9SAndroid Build Coastguard Worker torch.jit.script(fn) 15731*da0073e9SAndroid Build Coastguard Worker 15732*da0073e9SAndroid Build Coastguard Worker def test_kwargs_error_msg(self): 15733*da0073e9SAndroid Build Coastguard Worker def other(**kwargs): 15734*da0073e9SAndroid Build Coastguard Worker print(kwargs) 15735*da0073e9SAndroid Build Coastguard Worker 15736*da0073e9SAndroid Build Coastguard Worker def 
fn(): 15737*da0073e9SAndroid Build Coastguard Worker return other() 15738*da0073e9SAndroid Build Coastguard Worker 15739*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'): 15740*da0073e9SAndroid Build Coastguard Worker torch.jit.script(fn) 15741*da0073e9SAndroid Build Coastguard Worker 15742*da0073e9SAndroid Build Coastguard Worker def another_other(*args): 15743*da0073e9SAndroid Build Coastguard Worker print(args) 15744*da0073e9SAndroid Build Coastguard Worker 15745*da0073e9SAndroid Build Coastguard Worker def another_fn(): 15746*da0073e9SAndroid Build Coastguard Worker return another_other() 15747*da0073e9SAndroid Build Coastguard Worker 15748*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, 'variable number'): 15749*da0073e9SAndroid Build Coastguard Worker torch.jit.script(another_fn) 15750*da0073e9SAndroid Build Coastguard Worker 15751*da0073e9SAndroid Build Coastguard Worker def test_inferred_error_msg(self): 15752*da0073e9SAndroid Build Coastguard Worker """ 15753*da0073e9SAndroid Build Coastguard Worker Test that when we get a type mismatch on a function where we inferred 15754*da0073e9SAndroid Build Coastguard Worker the type to be tensor, a good error message is given. 
15755*da0073e9SAndroid Build Coastguard Worker """ 15756*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15757*da0073e9SAndroid Build Coastguard Worker def foo(a): 15758*da0073e9SAndroid Build Coastguard Worker return a 15759*da0073e9SAndroid Build Coastguard Worker 15760*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, (r"Expected a value of type \'Tensor \(inferred\)\'" 15761*da0073e9SAndroid Build Coastguard Worker r"[\S\s]*Inferred \'a\' to be of type \'Tensor\'")): 15762*da0073e9SAndroid Build Coastguard Worker foo("1") 15763*da0073e9SAndroid Build Coastguard Worker 15764*da0073e9SAndroid Build Coastguard Worker def test_type_comments_in_body(self): 15765*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15766*da0073e9SAndroid Build Coastguard Worker def foo(a, # type: int 15767*da0073e9SAndroid Build Coastguard Worker b, # type: int 15768*da0073e9SAndroid Build Coastguard Worker ): 15769*da0073e9SAndroid Build Coastguard Worker # type: (...) -> int 15770*da0073e9SAndroid Build Coastguard Worker # type: int 15771*da0073e9SAndroid Build Coastguard Worker return a + b 15772*da0073e9SAndroid Build Coastguard Worker 15773*da0073e9SAndroid Build Coastguard Worker class M(torch.nn.Module): 15774*da0073e9SAndroid Build Coastguard Worker def __init__(self, 15775*da0073e9SAndroid Build Coastguard Worker a, # type: int 15776*da0073e9SAndroid Build Coastguard Worker b # type: int 15777*da0073e9SAndroid Build Coastguard Worker ): 15778*da0073e9SAndroid Build Coastguard Worker # type: (...) 
-> None 15779*da0073e9SAndroid Build Coastguard Worker super().__init__() 15780*da0073e9SAndroid Build Coastguard Worker self.a = a # type: int 15781*da0073e9SAndroid Build Coastguard Worker self.b = b # type: int 15782*da0073e9SAndroid Build Coastguard Worker 15783*da0073e9SAndroid Build Coastguard Worker torch.jit.script(M(2, 3)) 15784*da0073e9SAndroid Build Coastguard Worker 15785*da0073e9SAndroid Build Coastguard Worker def test_input_keyword_in_schema(self): 15786*da0073e9SAndroid Build Coastguard Worker def f(x): 15787*da0073e9SAndroid Build Coastguard Worker return torch.ceil(input=x) 15788*da0073e9SAndroid Build Coastguard Worker 15789*da0073e9SAndroid Build Coastguard Worker inp = torch.randn(10) 15790*da0073e9SAndroid Build Coastguard Worker self.checkScript(f, (inp, )) 15791*da0073e9SAndroid Build Coastguard Worker 15792*da0073e9SAndroid Build Coastguard Worker def test_module_method_reassignment(self): 15793*da0073e9SAndroid Build Coastguard Worker class Foo(torch.nn.Module): 15794*da0073e9SAndroid Build Coastguard Worker def _forward(self, x): 15795*da0073e9SAndroid Build Coastguard Worker return x 15796*da0073e9SAndroid Build Coastguard Worker 15797*da0073e9SAndroid Build Coastguard Worker forward = _forward 15798*da0073e9SAndroid Build Coastguard Worker 15799*da0073e9SAndroid Build Coastguard Worker sm = torch.jit.script(Foo()) 15800*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2) 15801*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, sm(input)) 15802*da0073e9SAndroid Build Coastguard Worker 15803*da0073e9SAndroid Build Coastguard Worker # Tests the case where a torch.Tensor subclass (like Parameter) is used as 15804*da0073e9SAndroid Build Coastguard Worker # input. 
15805*da0073e9SAndroid Build Coastguard Worker def test_script_module_tensor_subclass_argument(self): 15806*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 15807*da0073e9SAndroid Build Coastguard Worker def parameter_script(x: torch.nn.Parameter): 15808*da0073e9SAndroid Build Coastguard Worker return x 15809*da0073e9SAndroid Build Coastguard Worker 15810*da0073e9SAndroid Build Coastguard Worker input = torch.ones(2, 2) 15811*da0073e9SAndroid Build Coastguard Worker self.assertEqual(input, parameter_script(input)) 15812*da0073e9SAndroid Build Coastguard Worker 15813*da0073e9SAndroid Build Coastguard Worker def test_save_load_attr_error(self): 15814*da0073e9SAndroid Build Coastguard Worker class Inner(nn.Module): 15815*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15816*da0073e9SAndroid Build Coastguard Worker return x 15817*da0073e9SAndroid Build Coastguard Worker 15818*da0073e9SAndroid Build Coastguard Worker class Wrapper(nn.Module): 15819*da0073e9SAndroid Build Coastguard Worker def __init__(self, inner): 15820*da0073e9SAndroid Build Coastguard Worker super().__init__() 15821*da0073e9SAndroid Build Coastguard Worker self.inner = inner 15822*da0073e9SAndroid Build Coastguard Worker 15823*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15824*da0073e9SAndroid Build Coastguard Worker # this attribute doesn't exist on `Inner` 15825*da0073e9SAndroid Build Coastguard Worker return self.inner.b(x) 15826*da0073e9SAndroid Build Coastguard Worker 15827*da0073e9SAndroid Build Coastguard Worker inner_module = torch.jit.script(Inner()) 15828*da0073e9SAndroid Build Coastguard Worker inner_module = self.getExportImportCopy(inner_module) 15829*da0073e9SAndroid Build Coastguard Worker wrapped = Wrapper(inner_module) 15830*da0073e9SAndroid Build Coastguard Worker # This should properly complain that `self.inner` doesn't have the attribute `b` 15831*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, 'has no 
attribute'): 15832*da0073e9SAndroid Build Coastguard Worker torch.jit.script(wrapped) 15833*da0073e9SAndroid Build Coastguard Worker 15834*da0073e9SAndroid Build Coastguard Worker def test_rescripting_loaded_modules(self): 15835*da0073e9SAndroid Build Coastguard Worker class InnerSubmod(nn.Module): 15836*da0073e9SAndroid Build Coastguard Worker __constants__ = ['my_constant'] 15837*da0073e9SAndroid Build Coastguard Worker 15838*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15839*da0073e9SAndroid Build Coastguard Worker super().__init__() 15840*da0073e9SAndroid Build Coastguard Worker self.foo = torch.nn.Buffer(torch.ones(1)) 15841*da0073e9SAndroid Build Coastguard Worker self.register_parameter("bar", torch.nn.Parameter(torch.ones(1))) 15842*da0073e9SAndroid Build Coastguard Worker self.baz = torch.ones(1) 15843*da0073e9SAndroid Build Coastguard Worker self.my_constant = 1 15844*da0073e9SAndroid Build Coastguard Worker 15845*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15846*da0073e9SAndroid Build Coastguard Worker return x + x 15847*da0073e9SAndroid Build Coastguard Worker 15848*da0073e9SAndroid Build Coastguard Worker class Inner(nn.Module): 15849*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 15850*da0073e9SAndroid Build Coastguard Worker super().__init__() 15851*da0073e9SAndroid Build Coastguard Worker self.submod = InnerSubmod() 15852*da0073e9SAndroid Build Coastguard Worker 15853*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15854*da0073e9SAndroid Build Coastguard Worker return self.submod(x) 15855*da0073e9SAndroid Build Coastguard Worker 15856*da0073e9SAndroid Build Coastguard Worker class Wrapper(nn.Module): 15857*da0073e9SAndroid Build Coastguard Worker def __init__(self, inner): 15858*da0073e9SAndroid Build Coastguard Worker super().__init__() 15859*da0073e9SAndroid Build Coastguard Worker self.inner = inner 15860*da0073e9SAndroid Build Coastguard Worker 15861*da0073e9SAndroid 
Build Coastguard Worker def forward(self, x): 15862*da0073e9SAndroid Build Coastguard Worker # access inner elements 15863*da0073e9SAndroid Build Coastguard Worker ret = self.inner.submod(x) + self.inner.submod.foo + self.inner.submod.bar + self.inner.submod.baz 15864*da0073e9SAndroid Build Coastguard Worker ret = ret + self.inner.submod.my_constant 15865*da0073e9SAndroid Build Coastguard Worker return ret 15866*da0073e9SAndroid Build Coastguard Worker 15867*da0073e9SAndroid Build Coastguard Worker inner_module = torch.jit.script(Inner()) 15868*da0073e9SAndroid Build Coastguard Worker wrapped = Wrapper(inner_module) 15869*da0073e9SAndroid Build Coastguard Worker self.checkModule(wrapped, torch.ones(1)) 15870*da0073e9SAndroid Build Coastguard Worker 15871*da0073e9SAndroid Build Coastguard Worker inner_module_loaded = self.getExportImportCopy(inner_module) 15872*da0073e9SAndroid Build Coastguard Worker wrapped_loaded = Wrapper(inner_module_loaded) 15873*da0073e9SAndroid Build Coastguard Worker self.assertEqual(wrapped(torch.ones(1)), wrapped_loaded(torch.ones(1))) 15874*da0073e9SAndroid Build Coastguard Worker 15875*da0073e9SAndroid Build Coastguard Worker def test_interpret_graph(self): 15876*da0073e9SAndroid Build Coastguard Worker def fn(x): 15877*da0073e9SAndroid Build Coastguard Worker return x.unfold(0, 1, 1) 15878*da0073e9SAndroid Build Coastguard Worker 15879*da0073e9SAndroid Build Coastguard Worker graph_str = """ 15880*da0073e9SAndroid Build Coastguard Worker graph(%a : Tensor, %b : Tensor): 15881*da0073e9SAndroid Build Coastguard Worker %c : Tensor = aten::mul(%a, %b) 15882*da0073e9SAndroid Build Coastguard Worker return (%c) 15883*da0073e9SAndroid Build Coastguard Worker """ 15884*da0073e9SAndroid Build Coastguard Worker graph = parse_ir(graph_str) 15885*da0073e9SAndroid Build Coastguard Worker a = torch.rand(10) 15886*da0073e9SAndroid Build Coastguard Worker b = torch.rand(10) 15887*da0073e9SAndroid Build Coastguard Worker test = 
torch._C._jit_interpret_graph(graph, (a, b)) 15888*da0073e9SAndroid Build Coastguard Worker ref = a * b 15889*da0073e9SAndroid Build Coastguard Worker self.assertEqual(test, ref) 15890*da0073e9SAndroid Build Coastguard Worker 15891*da0073e9SAndroid Build Coastguard Worker def test_signed_float_zero(self): 15892*da0073e9SAndroid Build Coastguard Worker 15893*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 15894*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15895*da0073e9SAndroid Build Coastguard Worker return torch.div(x, -0.) 15896*da0073e9SAndroid Build Coastguard Worker 15897*da0073e9SAndroid Build Coastguard Worker inp = torch.ones(1) 15898*da0073e9SAndroid Build Coastguard Worker self.checkModule(MyModule(), inp) 15899*da0073e9SAndroid Build Coastguard Worker 15900*da0073e9SAndroid Build Coastguard Worker def test_index_with_tuple(self): 15901*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 15902*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 15903*da0073e9SAndroid Build Coastguard Worker return x[(1,)] 15904*da0073e9SAndroid Build Coastguard Worker 15905*da0073e9SAndroid Build Coastguard Worker self.checkModule(MyModule(), (torch.ones(2, 3),)) 15906*da0073e9SAndroid Build Coastguard Worker 15907*da0073e9SAndroid Build Coastguard Worker def test_context_manager(self): 15908*da0073e9SAndroid Build Coastguard Worker class MyModule(torch.nn.Module): 15909*da0073e9SAndroid Build Coastguard Worker def forward(self, x, y): 15910*da0073e9SAndroid Build Coastguard Worker p = x + y 15911*da0073e9SAndroid Build Coastguard Worker q = p + 2.0 15912*da0073e9SAndroid Build Coastguard Worker return q 15913*da0073e9SAndroid Build Coastguard Worker 15914*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 2, dtype=torch.float) 15915*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, 2, dtype=torch.float) 15916*da0073e9SAndroid Build Coastguard Worker for fuser_name in 
['fuser0', 'fuser1', 'none']: 15917*da0073e9SAndroid Build Coastguard Worker with torch.jit.fuser(fuser_name): 15918*da0073e9SAndroid Build Coastguard Worker self.checkModule(MyModule(), (x, y)) 15919*da0073e9SAndroid Build Coastguard Worker 15920*da0073e9SAndroid Build Coastguard Worker def test_zero_dimension_tensor_trace(self): 15921*da0073e9SAndroid Build Coastguard Worker def f(x): 15922*da0073e9SAndroid Build Coastguard Worker return x[x > 0] 15923*da0073e9SAndroid Build Coastguard Worker jf = torch.jit.trace(f, torch.tensor(2., device="cpu")) 15924*da0073e9SAndroid Build Coastguard Worker 15925*da0073e9SAndroid Build Coastguard Worker# known to be failing in tracer 15926*da0073e9SAndroid Build Coastguard WorkerEXCLUDE_TRACED = { 15927*da0073e9SAndroid Build Coastguard Worker # The following fail due to #12024. 15928*da0073e9SAndroid Build Coastguard Worker # A prim::ListConstruct is involved and the indices get traced as TensorType, 15929*da0073e9SAndroid Build Coastguard Worker # which always require_grad. This causes a crash in autodiff. 15930*da0073e9SAndroid Build Coastguard Worker 'test___getitem___adv_index', 15931*da0073e9SAndroid Build Coastguard Worker 'test___getitem___adv_index_beg', 15932*da0073e9SAndroid Build Coastguard Worker 'test___getitem___adv_index_comb', 15933*da0073e9SAndroid Build Coastguard Worker 'test___getitem___adv_index_dup', 15934*da0073e9SAndroid Build Coastguard Worker 'test___getitem___adv_index_sub', 15935*da0073e9SAndroid Build Coastguard Worker 'test___getitem___adv_index_sub_2', 15936*da0073e9SAndroid Build Coastguard Worker 'test___getitem___adv_index_sub_3', 15937*da0073e9SAndroid Build Coastguard Worker 'test___getitem___adv_index_var', 15938*da0073e9SAndroid Build Coastguard Worker 15939*da0073e9SAndroid Build Coastguard Worker # jit doesn't support sparse tensors. 
15940*da0073e9SAndroid Build Coastguard Worker 'test_to_sparse', 15941*da0073e9SAndroid Build Coastguard Worker 'test_to_sparse_dim', 15942*da0073e9SAndroid Build Coastguard Worker} 15943*da0073e9SAndroid Build Coastguard Worker 15944*da0073e9SAndroid Build Coastguard WorkerEXCLUDE_TYPE_CHECK = { 15945*da0073e9SAndroid Build Coastguard Worker # slogdet tests use itemgetter to select its only differentiable output, 15946*da0073e9SAndroid Build Coastguard Worker # but this happens outside of the graph we handle, so there are fewer 15947*da0073e9SAndroid Build Coastguard Worker # reference outputs than graph outputs. 15948*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_1x1_neg_det', 15949*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_1x1_pos_det', 15950*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_distinct_singular_values', 15951*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_neg_det', 15952*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_pos_det', 15953*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_symmetric', 15954*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_symmetric_pd', 15955*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_batched_1x1_neg_det', 15956*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_batched_pos_det', 15957*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_batched_symmetric', 15958*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_batched_symmetric_pd', 15959*da0073e9SAndroid Build Coastguard Worker 'test_slogdet_batched_distinct_singular_values' 15960*da0073e9SAndroid Build Coastguard Worker} 15961*da0073e9SAndroid Build Coastguard Worker 15962*da0073e9SAndroid Build Coastguard Worker# chunk returns a list in scripting and we don't unpack the list, 15963*da0073e9SAndroid Build Coastguard Worker# Thus it won't be replaced by ConstantChunk and run AD. 
15964*da0073e9SAndroid Build Coastguard Worker# It's explicitly checked in test_chunk_constant_script_ad 15965*da0073e9SAndroid Build Coastguard Worker# Similary for split, it's replaced by split_with_sizes in tracing, 15966*da0073e9SAndroid Build Coastguard Worker# but we don't have AD formula for aten::split(Tensor, int[], int), 15967*da0073e9SAndroid Build Coastguard Worker# an op registered in JIT so AD is not triggered in scripting. 15968*da0073e9SAndroid Build Coastguard WorkerEXCLUDE_SCRIPT_AD_CHECK = { 15969*da0073e9SAndroid Build Coastguard Worker 'test_chunk', 15970*da0073e9SAndroid Build Coastguard Worker 'test_chunk_dim', 15971*da0073e9SAndroid Build Coastguard Worker 'test_chunk_dim_neg0', 15972*da0073e9SAndroid Build Coastguard Worker 'test_split_size_list', 15973*da0073e9SAndroid Build Coastguard Worker 'test_split_size_list_dim', 15974*da0073e9SAndroid Build Coastguard Worker 'test_split_size_list_dim_neg0', 15975*da0073e9SAndroid Build Coastguard Worker 'test_tensor_indices_sections', 15976*da0073e9SAndroid Build Coastguard Worker 'test_tensor_indices_sections_dim', 15977*da0073e9SAndroid Build Coastguard Worker 'test_tensor_indices_sections_dim_neg0', 15978*da0073e9SAndroid Build Coastguard Worker 'test_tensor_split_sections', 15979*da0073e9SAndroid Build Coastguard Worker 'test_tensor_split_sections_dim', 15980*da0073e9SAndroid Build Coastguard Worker 'test_tensor_split_sections_dim_neg0' 15981*da0073e9SAndroid Build Coastguard Worker} 15982*da0073e9SAndroid Build Coastguard Worker 15983*da0073e9SAndroid Build Coastguard WorkerEXCLUDE_PYTHON_PRINT = { 15984*da0073e9SAndroid Build Coastguard Worker # no support for BroadcastingList in python printer 15985*da0073e9SAndroid Build Coastguard Worker 'test_nn_max_unpool1d', 15986*da0073e9SAndroid Build Coastguard Worker 'test_nn_max_unpool2d', 15987*da0073e9SAndroid Build Coastguard Worker 'test_nn_max_unpool3d', 15988*da0073e9SAndroid Build Coastguard Worker 'test_nn_max_pool1d', 
15989*da0073e9SAndroid Build Coastguard Worker 'test_nn_max_pool2d', 15990*da0073e9SAndroid Build Coastguard Worker 'test_nn_max_pool3d', 15991*da0073e9SAndroid Build Coastguard Worker 'test_nn_max_pool1d_with_indices', 15992*da0073e9SAndroid Build Coastguard Worker} 15993*da0073e9SAndroid Build Coastguard Worker 15994*da0073e9SAndroid Build Coastguard WorkerEXCLUDE_ALIAS = { 15995*da0073e9SAndroid Build Coastguard Worker # aliases, which may appear in method_tests but are tested elsewhere 15996*da0073e9SAndroid Build Coastguard Worker 'true_divide', 15997*da0073e9SAndroid Build Coastguard Worker 15998*da0073e9SAndroid Build Coastguard Worker # Disable tests for lu from common_methods_invocations.py 15999*da0073e9SAndroid Build Coastguard Worker # TODO(@nikitaved) Enable jit tests once autograd.Function does support scripting 16000*da0073e9SAndroid Build Coastguard Worker 'lu' 16001*da0073e9SAndroid Build Coastguard Worker} 16002*da0073e9SAndroid Build Coastguard Worker 16003*da0073e9SAndroid Build Coastguard Worker 16004*da0073e9SAndroid Build Coastguard Workerclass TestJitGeneratedModule(JitTestCase): 16005*da0073e9SAndroid Build Coastguard Worker pass 16006*da0073e9SAndroid Build Coastguard Worker 16007*da0073e9SAndroid Build Coastguard Worker 16008*da0073e9SAndroid Build Coastguard Workerclass TestJitGeneratedFunctional(JitTestCase): 16009*da0073e9SAndroid Build Coastguard Worker pass 16010*da0073e9SAndroid Build Coastguard Worker 16011*da0073e9SAndroid Build Coastguard Worker# UBSAN per-function exclusions don't seem to work with OpenMP pragmas, 16012*da0073e9SAndroid Build Coastguard Worker# and we have to disable the failing tests here instead. 
UBSAN_DISABLED_TESTS = [
    "test___rdiv___constant",
    "test___rdiv___scalar_constant",
    "test_addcdiv",
    "test_addcdiv_broadcast_all",
    "test_addcdiv_broadcast_rhs",
    "test_addcdiv_scalar",
    "test_addcdiv_scalar_broadcast_lhs",
    "test_addcdiv_scalar_broadcast_rhs",
    "test_addcdiv_scalar_scale",
    "test_addcdiv_scalar_scale_broadcast_lhs",
    "test_addcdiv_scalar_scale_broadcast_rhs",
    "test_addcdiv_scale",
    "test_addcdiv_scale_broadcast_all",
    "test_addcdiv_scale_broadcast_rhs",
    "test_add_broadcast_all",
    "test_add_broadcast_lhs",
    "test_add_broadcast_rhs",
    "test_add_constant",
    "test_add_scalar",
    "test_add_scalar_broadcast_lhs",
    "test_add_scalar_broadcast_rhs",
    "test_div",
    "test_div_broadcast_all",
    "test_div_broadcast_lhs",
    "test_div_broadcast_rhs",
    "test_div_scalar",
    "test_div_scalar_broadcast_lhs",
    "test_div_scalar_broadcast_rhs",
    "test_rsqrt",
    "test_rsqrt_scalar",
    "test_add",
    "test_reciprocal",
    "test_reciprocal_scalar",
]

# Module-level size constants. NOTE(review): the names suggest
# large / medium / small dimensions -- confirm against their users.
L = 20
M = 10
S = 5


def add_nn_module_test(*args, **kwargs):
    """Generate and register a JIT test for a single nn module test spec.

    ``kwargs`` is one test spec (an entry of ``module_tests``,
    ``new_module_tests``, ``additional_module_tests`` or ``criterion_tests``).
    The generated ``do_test`` wraps the module in a ``torch.jit.ScriptModule``
    whose scripted method forwards to it, then compares that wrapper against
    the plain (eager) module using ``check_against_reference``. The test is
    attached to ``TestJitGeneratedModule`` through ``post_add_test``.

    Specs whose ``desc`` contains ``'eval'`` are not registered at all.
    Positional ``args`` are accepted for call compatibility but unused.
    """
    no_grad = kwargs.get('no_grad', False)

    if 'desc' in kwargs and 'eval' in kwargs['desc']:
        # eval() is not supported, so skip these tests
        return

    test_name = get_nn_mod_test_name(**kwargs)

    @suppress_warnings
    def do_test(self):
        if test_name in EXCLUDE_SCRIPT_MODULES:
            return
        if not kwargs.get('check_jit', True):
            raise unittest.SkipTest('module test skipped on JIT')

        default_dtype = torch.get_default_dtype()
        if 'default_dtype' in kwargs and kwargs['default_dtype'] is not None:
            default_dtype = kwargs['default_dtype']

        module_name = get_nn_module_name_from_kwargs(**kwargs)

        if 'constructor' in kwargs:
            nn_module = kwargs['constructor']
        else:
            nn_module = getattr(torch.nn, module_name)

        if "FunctionalModule" in str(nn_module):
            return

        with set_default_dtype(default_dtype):
            if 'constructor_args_fn' in kwargs:
                constructor_args = kwargs['constructor_args_fn']()
            else:
                constructor_args = kwargs.get('constructor_args', ())

            def create_script_module(*args, **kwargs):
                """Construct a script module that passes arguments through to self.submodule"""
                _, tensors, actuals = get_script_args(args)

                method_args = ', '.join(['self'] + actuals)
                call_args_str = ', '.join(actuals)
                call = f"self.submodule({call_args_str})"
                script = script_method_template.format(method_args, call)

                # NOTE(review): ``kwargs`` here is this function's own keyword
                # arguments (it shadows the test spec), so 'is_constant' is
                # read from the caller of create_script_module, not from the
                # spec -- confirm this is intended.
                submodule_constants = []
                if kwargs.get('is_constant'):
                    submodule_constants = ['submodule']

                # Create module to use the script method
                class TheModule(torch.jit.ScriptModule):
                    __constants__ = submodule_constants

                    def __init__(self) -> None:
                        super().__init__()
                        self.submodule = nn_module(*constructor_args)

                def make_module(script):
                    module = TheModule()
                    # check __repr__
                    str(module)
                    module.define(script)
                    return module

                module = make_module(script)
                self.assertExportImportModule(module, tensors)
                create_script_module.last_graph = module.graph
                return module(*args)

            # Construct a normal nn module to stay consistent with create_script_module
            # and make use of a single global rng_state in module initialization
            def create_nn_module(*args, **kwargs):
                module = nn_module(*constructor_args)
                return module(*args)

            # Set up inputs from tuple of sizes or constructor fn
            dtype = torch.get_default_dtype()
            if 'input_fn' in kwargs:
                inputs = kwargs['input_fn']()
                if isinstance(inputs, Tensor):
                    inputs = (inputs,)

                # All-complex inputs: switch to the complex counterpart of the
                # (already applied) default dtype.
                if all(tensor.is_complex() for tensor in inputs):
                    if dtype == torch.float:
                        dtype = torch.cfloat
                    elif dtype == torch.double:
                        dtype = torch.cdouble
                    else:
                        raise AssertionError(f"default_dtype {default_dtype} is not supported")

            else:
                inputs = (kwargs['input_size'],)

            if 'target_size' in kwargs:
                inputs = inputs + (kwargs['target_size'],)
            elif 'target_fn' in kwargs:
                if torch.is_tensor(inputs):
                    inputs = (inputs,)
                inputs = inputs + (kwargs['target_fn'](),)
            elif 'target' in kwargs:
                inputs = inputs + (kwargs['target'],)

            # Extra parameters to forward()
            if 'extra_args' in kwargs:
                inputs = inputs + kwargs['extra_args']

            args_variable, _ = create_input(inputs, dtype=dtype)
            f_args_variable = deepcopy(unpack_variables(args_variable))

            # TODO(issue#52052) Neither this nor no_grad should be required
            # if check_against_reference() is updated to check gradients
            # w.r.t. weights and then only check w.r.t. inputs if any
            # inputs require it.
            any_requires_grad = any(arg.requires_grad for arg in f_args_variable)

            # Check against Python module as reference
            check_against_reference(self, create_script_module, create_nn_module,
                                    lambda x: x, f_args_variable,
                                    no_grad=no_grad or not any_requires_grad)

    if 'slowTest' in kwargs:
        do_test = slowTest(do_test)

    post_add_test(test_name, (), do_test, TestJitGeneratedModule)


def post_add_test(test_name, skipTestIf, do_test, test_class):
    """Attach ``do_test`` to ``test_class`` under the attribute ``test_name``.

    Each decorator in ``skipTestIf`` is applied to ``do_test`` in order. The
    test is not attached when ``TEST_WITH_UBSAN`` is set and ``test_name`` is
    listed in ``UBSAN_DISABLED_TESTS``.

    Raises:
        AssertionError: if ``test_class`` already has an attribute named
            ``test_name``. Raised explicitly rather than via ``assert`` so the
            duplicate-name check is not stripped under ``python -O``.
    """
    if hasattr(test_class, test_name):
        raise AssertionError('Two tests have the same name: ' + test_name)

    for skip in skipTestIf:
        do_test = skip(do_test)

    if not (TEST_WITH_UBSAN and test_name in UBSAN_DISABLED_TESTS):
        setattr(test_class, test_name, do_test)


def normalize_check_ad(check_ad, name):
    """Normalize a ``check_ad`` spec into a 3-element list.

    The result has the form ``[flag, List[str], List[str]]``. A missing first
    element defaults to ``False``, a missing second element to
    ``['aten::' + name]``, and a missing third to ``[]``. Any element that is
    a plain ``str`` is wrapped into a single-element list.

    Raises:
        Exception: if ``check_ad`` has more than three elements.
    """
    if len(check_ad) == 0:
        check_ad = [False, ['aten::' + name], []]
    elif len(check_ad) == 1:
        check_ad = [check_ad[0], ['aten::' + name], []]
    elif len(check_ad) == 2:
        check_ad = [check_ad[0], check_ad[1], []]
    elif len(check_ad) == 3:
        check_ad = list(check_ad)
    else:
        raise Exception('Invalid check_ad, requires (bool, str|List[str], str|List[str])')  # noqa: TRY002

    check_ad = [[t] if isinstance(t, str) else t for t in check_ad]

    return check_ad


class TestProducerVersion(TestCase):

    def test_version(self):
        # issue gh-32561
        self.assertTrue(torch.__version__.startswith(torch.onnx.producer_version))


for test in module_tests + new_module_tests + additional_module_tests:
    add_nn_module_test(**test)

for test in criterion_tests:
    # Pass no_grad through a merged copy instead of writing it into the spec
    # dict: criterion_tests is not owned by this loop, and mutating its
    # entries would leak the flag to any other user of the same specs.
    add_nn_module_test(**{**test, 'no_grad': True})

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()
    import jit.test_module_interface
    suite = unittest.findTestCases(jit.test_module_interface)
    unittest.TextTestRunner().run(suite)