# Owner(s): ["module: vmap"]

import functools
import itertools
import types
import warnings

import torch
import torch.nn.functional as F
from torch import Tensor
from torch._vmap_internals import vmap
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


# Substring of the warning emitted by the vmap fallback path; tests that
# assert on fallback warnings match against this pattern.
FALLBACK_REGEX = r"There is a performance drop"


class EnableVmapFallbackWarnings:
    """Context manager that turns on vmap fallback warnings for its scope.

    On entry it records the current debug setting and enables the warnings;
    on exit it restores whatever setting was previously active, so nesting
    and reentry are safe.
    """

    def __enter__(self):
        # Remember the current setting so __exit__ can restore it.
        self.prev_state = torch._C._debug_only_are_vmap_fallback_warnings_enabled()
        torch._C._debug_only_display_vmap_fallback_warnings(True)

    def __exit__(self, *ignored):
        # Restore the setting captured in __enter__ (not simply False).
        torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state)


class TestVmapAPILegacy(TestCase):
def test_non_tensor_output_raises(self): 30*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 31*da0073e9SAndroid Build Coastguard Worker ValueError, "got type <class 'float'> as the return" 32*da0073e9SAndroid Build Coastguard Worker ): 33*da0073e9SAndroid Build Coastguard Worker output = vmap(lambda x: 3.14)(torch.ones(3)) 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker def multiple_outputs(x): 36*da0073e9SAndroid Build Coastguard Worker return x, 3 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"): 39*da0073e9SAndroid Build Coastguard Worker vmap(multiple_outputs)(torch.ones(3)) 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Worker def test_different_map_dim_size_raises(self): 42*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2) 43*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3) 44*da0073e9SAndroid Build Coastguard Worker expected_msg = ( 45*da0073e9SAndroid Build Coastguard Worker "Expected all tensors to have the same size in the mapped dimension" 46*da0073e9SAndroid Build Coastguard Worker ) 47*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, expected_msg): 48*da0073e9SAndroid Build Coastguard Worker vmap(torch.mul)(x, y) 49*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, expected_msg): 50*da0073e9SAndroid Build Coastguard Worker vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y)) 51*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, expected_msg): 52*da0073e9SAndroid Build Coastguard Worker vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))( 53*da0073e9SAndroid Build Coastguard Worker {"x": x, "y": y} 54*da0073e9SAndroid Build Coastguard Worker ) 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build 
    def test_func_with_no_inputs(self):
        """vmap requires at least one input, regardless of the fn's arity."""
        expected_msg = "got no inputs"

        def foo():
            return torch.randn(3)

        def bar(x):
            return torch.randn(3)

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(foo)()

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(bar)()

    def test_constant_function(self):
        """A fn ignoring its input is broadcast across the batch dim."""
        output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
        self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))

    def test_single_input(self):
        """Elementwise op vmapped over one batched input."""
        x = torch.randn(2, 3)

        def square(x):
            return x * x

        output = vmap(square)(x)
        self.assertEqual(output, x * x)

    def test_multiple_inputs(self):
        """Binary op vmapped over two batched inputs with default in_dims."""
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        output = vmap(torch.mul)(x, y)
        self.assertEqual(output, x * y)

    def test_multiple_outputs(self):
        """Each output of a multi-output fn is batched independently."""

        def foo(x):
            return x * x, x * x * x

        x = torch.randn(3)
        outputs = vmap(foo)(x)
        self.assertEqual(outputs[0], x * x)
        self.assertEqual(outputs[1], x * x * x)

    def test_multiple_outputs_error_cases(self):
        """Tuples of Tensors are valid outputs; lists are rejected."""

        # This is the same thing as
        # def returns_tuple_of_tensors(x):
        #     return x, x
        def returns_tuple_of_tensors(x):
            return (x, x)

        def returns_list_of_two_tensors(x):
            return [x, x]

        def returns_list_of_one_tensor(x):
            return [x]

        x = torch.randn(3)

        # should not throw
        vmap(returns_tuple_of_tensors)(x)

        # jax supports these, but we don't yet
        msg = "must only return Tensors, got type <class 'list'>"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_two_tensors)(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_one_tensor)(x)

    def test_nested_with_same_map_dim(self):
        """Nested vmaps each consume one leading dim of the same inputs."""
        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        output = vmap(vmap(torch.mul))(x, y)
        self.assertEqual(output, x * y)

        # Triple nesting consumes all three dims; result is still x * y.
        output = vmap(vmap(vmap(torch.mul)))(x, y)
        self.assertEqual(output, x * y)

    def test_nested_with_different_map_dim(self):
        """Nested vmaps over distinct inputs produce an outer product of
        batch dims (result shape stacks each level's map dim in order)."""
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
        self.assertEqual(output.shape, (2, 5, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        z = torch.randn(7, 3)
        output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
        self.assertEqual(output.shape, (2, 5, 7, 3))
        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)

    def test_noop_in_inner_vmap(self):
        """Inner vmap that ignores its operand still adds its batch dim."""
        x = torch.randn(3)
        y = torch.randn(5)
        output = vmap(lambda x: vmap(lambda y: x)(y))(x)
        self.assertEqual(output, x.view(3, 1).expand(3, 5))

    def test_unsupported_op_err_msg(self):
        """Ops the fallback can't handle raise with a descriptive message."""
        # Unsupported view op
        tensor = torch.randn(2, 3)
        msg = (
            r"Batching rule not implemented for aten::.+; the "
            r"fallback path doesn't work on out= or view ops"
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(torch.ravel)(tensor)

        def out_op(x, y):
            return torch.abs(x, out=y)

        # out= variants are likewise unsupported by the fallback.
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(out_op)(tensor, tensor)

        tensor = torch.randn(2)
        # The fallback doesn't support TensorList
        with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
            vmap(lambda t: torch.atleast_1d([t]))(tensor)

        # Don't support non-tensor returns. This is a limitation of vmap;
        # functions that don't return tensors must be special cased
        with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
            vmap(torch.Tensor.item)(tensor)

    def test_nonzero_out_dims(self):
        """out_dims moves the batch dim to the requested output position;
        for the identity fn the result aliases the input (same data_ptr)."""
        # Basic test
        tensor = torch.randn(2, 3)
        result = vmap(lambda x: x, out_dims=1)(tensor)
        self.assertEqual(result, tensor.permute(1, 0))
        # Identity under vmap should be a view, not a copy.
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # Test that the batch dimension gets permuted to dim 2
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 0, 3))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # negative out_dim
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=-1)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 3, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # check that out_dims works on ALL outputs
        tensor = torch.randn(2, 3, 5, 7)
        other = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
        self.assertEqual(
            result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3))
        )

        # use out_dims with the maximum vmap-able tensor dims (64 dims)
        ndims = 64
        shape = [2] + [1] * (ndims - 1)
        expected_shape = [1, 1, 2] + [1] * (ndims - 3)
        tensor = torch.randn(shape)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result.shape, expected_shape)

        # test something that is not the identity function
        def foo(x, y):
            return x, x * y, x * y * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=1)(x, y)
        self.assertEqual(
            result,
            (
                x.permute(1, 0, 2),
                (x * y).permute(1, 0, 2),
                (x * y * y).permute(1, 0, 2),
            ),
        )

    def test_multiple_out_dims(self):
        """A tuple out_dims assigns a (possibly different) dim per output."""

        def foo(x):
            return x, x

        def bar(x, y):
            return x, x, x, x * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=(0, 1))(x)
        self.assertEqual(result, (x, x.permute(1, 0, 2)))

        # Mixed positive and negative out_dims across four outputs.
        result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
        expected = (
            x.permute(1, 2, 0),
            x,
            x.permute(1, 0, 2),
            (x * y).permute(1, 2, 0),
        )
        self.assertEqual(result, expected)

    def test_nested_out_dims(self):
        """out_dims composes through nested vmaps (inner dim placed first,
        then the outer vmap's out_dims applied on top)."""
        y = torch.randn(2, 3, 5, 7)

        # Inner vmap has non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
        self.assertEqual(result.shape, (2, 5, 3, 7))
        self.assertEqual(result, y.permute(0, 2, 1, 3))

        # all vmaps have non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
        self.assertEqual(result.shape, (5, 2, 3, 7))
        self.assertEqual(result, y.permute(2, 0, 1, 3))

        # throwing in some negative out_dims
        result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
        self.assertEqual(result.shape, (5, 7, 3, 2))
        self.assertEqual(result, y.permute(2, 3, 1, 0))

        # testing fn that isn't the identity
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
        self.assertEqual(result.shape, (3, 2, 5))
        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))

    def test_out_dims_edge_case(self):
        """A 1-tuple out_dims is accepted for a single-output function."""

        def foo(x):
            return x

        # Test that we accept out_dims=(1,) for a function with one output.
        tensor = torch.randn(2, 3)
        expected = vmap(foo, out_dims=1)(tensor)
        result = vmap(foo, out_dims=(1,))(tensor)
        self.assertEqual(result, expected)

    def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
        """Non-int out_dims values (incl. None) raise ValueError."""
        msg = "`out_dims` must be an int or a tuple of int"
        tensor = torch.randn(2, 3)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims="lol")(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=("lol",))(tensor)
        # NOTE: unlike in_dims, None is NOT a valid out_dim.
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=None)(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(None,))(tensor)

    def test_out_dims_and_num_outputs_mismatch_err_msg(self):
        """len(out_dims) must equal the number of outputs, both directions."""
        msg = "`out_dims` must have one dim per output"
        x = torch.randn(2, 3, 5)

        # Too many out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(0, 0))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)

        # Too few out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x), out_dims=(0,))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)

    def test_out_dim_out_of_bounds_err_msg(self):
        """An out_dim outside the output's valid dim range raises IndexError."""
        # TODO(rzou): This error message isn't that great. It comes straight
        # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
        # the error message in the future in C++
        msg = "Dimension out of range"
        x = torch.randn(2, 3, 5)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=3)(x)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=-4)(x)

    def test_non_zero_in_dims(self):
        """in_dims selects which input dim is mapped over; the batch dim of
        the output still defaults to position 0."""
        tensor = torch.randn(2, 3, 5)

        # Implicit out_dims = 0; vmap will move the batch dim to the front.
        output = vmap(lambda x: x, (1,))(tensor)
        self.assertEqual(output, tensor.permute(1, 0, 2))
        # Identity under vmap aliases the input storage.
        self.assertEqual(output.data_ptr(), tensor.data_ptr())

        x = torch.randn(2, 3)
        y = torch.randn(3, 2)
        output = vmap(torch.mul, (0, 1))(x, y)
        self.assertEqual(output, x * y.t())
        output = vmap(torch.mul, (1, 0))(x, y)
        self.assertEqual(output, x.t() * y)

    def test_none_in_dims(self):
        """A None in_dim marks an argument as unmapped (broadcast)."""
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # None in_dim for a Tensor means we don't map over it
        output = vmap(torch.mul, (0, None))(x, y)
        self.assertEqual(output.shape, (2, 2, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        # None in_dim for non-tensor arguments
        output = vmap(torch.mul, (0, None))(x, 2)
        self.assertEqual(output, x * 2)

    def test_nested_non_default_in_dims(self):
        """Each nesting level may map a different dim of each input."""
        x = torch.rand(5, 2, 3)
        y = torch.rand(3, 5, 2)
        result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y)
        self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))

    def test_non_default_in_dims_out_dims(self):
        """Combinations of non-default in_dims and out_dims, for identity
        (where aliasing is preserved) and for a real operation."""
        x = torch.randn(2, 3, 5)

        # Same in_dim as out_dim, vmap over identity
        result = vmap(lambda x: x, in_dims=1, out_dims=1)(x)
        self.assertEqual(result, x)
        self.assertEqual(result.data_ptr(), x.data_ptr())

        # Different in_dim from out_dim, vmap over identity
        result = vmap(lambda x: x, in_dims=2, out_dims=1)(x)
        self.assertEqual(result.shape, (2, 5, 3))
        self.assertEqual(result, x.transpose(1, 2))
        # Still a view of the same storage despite the transpose.
        self.assertEqual(result.data_ptr(), x.data_ptr())

        def foo(x):
            return x * 2

        # Same in_dim as out_dim, vmap over operation
        result = vmap(foo, in_dims=1, out_dims=1)(x)
        self.assertEqual(result, x * 2)

        # Different in_dim as out_dim, vmap over operation
        result = vmap(foo, in_dims=2, out_dims=1)(x)
        self.assertEqual(result.shape, (2, 5, 3))
        self.assertEqual(result, (x * 2).transpose(1, 2))

        # Basic nested test.
        result = vmap(vmap(foo, 1, 1), 1, 1)(x)
        self.assertEqual(result, x * 2)

    def test_accepts_nested_inputs(self):
        """Inputs may be arbitrarily nested tuples/lists/dicts of Tensors,
        with in_dims either a bare 0 or a matching nested structure."""
        # NOTE(review): B0 appears unused in this method — candidate for
        # removal upstream.
        B0 = 2
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # Single layer of nesting
        out = vmap(lambda z: z[0] + z[1])((x, y))
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y))
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
        self.assertEqual(out, x + y)

        out = vmap(lambda z: z[0] + z[1])([x, y])
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y])
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y])
        self.assertEqual(out, x + y)

        out = vmap(lambda z: z["x"] + z["y"])({"x": x, "y": y})
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z["x"] + z["y"], in_dims=(0,))({"x": x, "y": y})
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
            {"x": x, "y": y}
        )
        self.assertEqual(out, x + y)

        # Multiple layers of nesting
        out_fn = vmap(lambda z: z["x"][0] + z["x"][1][0] + z["y"][0] + z["y"][1])
        out = out_fn({"x": [x, (x,)], "y": [y, y]})
        self.assertEqual(out, x + x + y + y)

    def test_in_dims_wrong_type_err_msg(self):
        """in_dims must be an int or (nested) tuple; other containers raise."""
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r"expected `in_dims` to be int or a \(potentially nested\) tuple"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, [0, 0])(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, set({0}))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, "lol")(x, y)
        # A list in_dims is invalid even when the *input* is a list.
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y])
        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_not_enough_in_dims_err_msg(self):
        """in_dims structure must match the structure of the inputs."""
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r"in_dims is not compatible with the structure of `inputs`"

        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, (0,))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, (0, 0, 0))(x, y)
        # Nested structure mismatches: wrong length, wrong container type.
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y])
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y])
        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_integer_in_dim_but_not_tensor_input_err_msg(self):
        """An integer in_dim paired with a non-Tensor input raises."""

        # NOTE(review): foo and bar appear unused in this method — candidates
        # for removal upstream.
        def foo(xy):
            return xy[0] * xy[1]

        def bar(x, yz):
            return x * yz[0] * yz[1]

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # the following are errors in jax (and will always be errors)
        msg = "Got in_dim=0 for an input but the input is of type"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum)(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum, (0, 0))(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1])
        # The following should not throw
        vmap(torch.sum, (0, None))(x, 0)

    def test_in_dim_not_in_tensor_err_msg(self):
        """An in_dim outside the input Tensor's dim range raises."""

        def foo(x):
            return x * x

        x = torch.randn(2, 3)
        # NOTE(review): y appears unused in this method — candidate for
        # removal upstream.
        y = torch.randn(2, 3)

        msg = r"Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w"
        # 0-dim tensors have no dim to map over at all.
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo)(torch.randn([]))
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(0,))(torch.randn([]))
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(-1,))(x)
Worker with self.assertRaisesRegex(ValueError, msg): 488*da0073e9SAndroid Build Coastguard Worker vmap(foo, in_dims=(2,))(y) 489*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, msg): 490*da0073e9SAndroid Build Coastguard Worker vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y]) 491*da0073e9SAndroid Build Coastguard Worker # the following should not throw 492*da0073e9SAndroid Build Coastguard Worker vmap(foo, in_dims=(0,))(torch.randn(2, 3)) 493*da0073e9SAndroid Build Coastguard Worker vmap(foo, in_dims=(1,))(torch.randn(2, 3)) 494*da0073e9SAndroid Build Coastguard Worker 495*da0073e9SAndroid Build Coastguard Worker def test_fallback_does_not_warn_by_default(self): 496*da0073e9SAndroid Build Coastguard Worker # NB: One day we will implement a batching rule for torch.atan2. 497*da0073e9SAndroid Build Coastguard Worker # If/when we do, this test should be replaced to test the fallback 498*da0073e9SAndroid Build Coastguard Worker # path on another operator to avoid bitrot. 499*da0073e9SAndroid Build Coastguard Worker op = torch.atan2 500*da0073e9SAndroid Build Coastguard Worker x = torch.randn(11) 501*da0073e9SAndroid Build Coastguard Worker y = torch.randn(11) 502*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 503*da0073e9SAndroid Build Coastguard Worker result = vmap(op)(x, y) 504*da0073e9SAndroid Build Coastguard Worker # The single warning here is the "vmap is experimental" 505*da0073e9SAndroid Build Coastguard Worker # warning, not a warning from the vmap fallback path. 506*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 1) 507*da0073e9SAndroid Build Coastguard Worker 508*da0073e9SAndroid Build Coastguard Worker def test_fallback_warns_when_warnings_are_enabled(self): 509*da0073e9SAndroid Build Coastguard Worker # NB: One day we will implement a batching rule for torch.atan2. 
510*da0073e9SAndroid Build Coastguard Worker # If/when we do, this test should be replaced to test the fallback 511*da0073e9SAndroid Build Coastguard Worker # path on another operator to avoid bitrot. 512*da0073e9SAndroid Build Coastguard Worker op = torch.atan2 513*da0073e9SAndroid Build Coastguard Worker x = torch.randn(11) 514*da0073e9SAndroid Build Coastguard Worker y = torch.randn(11) 515*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 516*da0073e9SAndroid Build Coastguard Worker with EnableVmapFallbackWarnings(): 517*da0073e9SAndroid Build Coastguard Worker result = vmap(op)(x, y) 518*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 2) 519*da0073e9SAndroid Build Coastguard Worker self.assertRegex(str(wa[-1].message), FALLBACK_REGEX) 520*da0073e9SAndroid Build Coastguard Worker 521*da0073e9SAndroid Build Coastguard Worker def _assert_uses_vmap_fallback(self, vmap_args, inputs): 522*da0073e9SAndroid Build Coastguard Worker with warnings.catch_warnings(record=True) as wa: 523*da0073e9SAndroid Build Coastguard Worker with EnableVmapFallbackWarnings(): 524*da0073e9SAndroid Build Coastguard Worker result = vmap(*vmap_args)(*inputs) 525*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(wa), 2) 526*da0073e9SAndroid Build Coastguard Worker self.assertRegex(str(wa[-1].message), FALLBACK_REGEX) 527*da0073e9SAndroid Build Coastguard Worker 528*da0073e9SAndroid Build Coastguard Worker def test_fallback_zero_dim(self): 529*da0073e9SAndroid Build Coastguard Worker # NB: One day we will implement a batching rule for torch.atan2. 530*da0073e9SAndroid Build Coastguard Worker # If/when we do, this test should be replaced to test the fallback 531*da0073e9SAndroid Build Coastguard Worker # path on another operator to avoid bitrot. 
532*da0073e9SAndroid Build Coastguard Worker op = torch.atan2 533*da0073e9SAndroid Build Coastguard Worker x = torch.randn(11) 534*da0073e9SAndroid Build Coastguard Worker y = torch.randn(11) 535*da0073e9SAndroid Build Coastguard Worker self._assert_uses_vmap_fallback((op,), (x, y)) 536*da0073e9SAndroid Build Coastguard Worker 537*da0073e9SAndroid Build Coastguard Worker B0, B1 = 0, 3 538*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, 11) 539*da0073e9SAndroid Build Coastguard Worker y = torch.randn(11) 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker msg = "The fallback path does not support vmap over dims of size 0" 542*da0073e9SAndroid Build Coastguard Worker 543*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 544*da0073e9SAndroid Build Coastguard Worker vmap(op, (0, None))(x, y) 545*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 546*da0073e9SAndroid Build Coastguard Worker vmap(op, (None, 0))(y, x) 547*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 548*da0073e9SAndroid Build Coastguard Worker vmap(op)(x, x) 549*da0073e9SAndroid Build Coastguard Worker 550*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, B1, 11) 551*da0073e9SAndroid Build Coastguard Worker y = torch.randn(B1, 11) 552*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 553*da0073e9SAndroid Build Coastguard Worker vmap(op, (0, None))(x, y) 554*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 555*da0073e9SAndroid Build Coastguard Worker vmap(op, (None, 0))(y, x) 556*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 557*da0073e9SAndroid Build Coastguard Worker vmap(op)(x, x) 558*da0073e9SAndroid Build Coastguard Worker 559*da0073e9SAndroid Build Coastguard Worker def test_fallback_atan2(self): 
560*da0073e9SAndroid Build Coastguard Worker # NB: One day we will implement a batching rule for torch.atan2. 561*da0073e9SAndroid Build Coastguard Worker # If/when we do, this test should be replaced to test the fallback 562*da0073e9SAndroid Build Coastguard Worker # path on another operator to avoid bitrot. 563*da0073e9SAndroid Build Coastguard Worker op = torch.atan2 564*da0073e9SAndroid Build Coastguard Worker 565*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 7, 11) 566*da0073e9SAndroid Build Coastguard Worker y = torch.randn(5, 7, 11) 567*da0073e9SAndroid Build Coastguard Worker 568*da0073e9SAndroid Build Coastguard Worker self._assert_uses_vmap_fallback((op,), (x, y)) 569*da0073e9SAndroid Build Coastguard Worker 570*da0073e9SAndroid Build Coastguard Worker # fallback on torch.atan2 571*da0073e9SAndroid Build Coastguard Worker x = torch.randn(7, 11, 5) 572*da0073e9SAndroid Build Coastguard Worker y = torch.randn(5, 7, 11) 573*da0073e9SAndroid Build Coastguard Worker result = vmap(op, (2, 0))(x, y) 574*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, op(x.permute(2, 0, 1), y)) 575*da0073e9SAndroid Build Coastguard Worker 576*da0073e9SAndroid Build Coastguard Worker # fallback on torch.atan2, nested vmap 577*da0073e9SAndroid Build Coastguard Worker x = torch.randn(7, 11, 5) 578*da0073e9SAndroid Build Coastguard Worker y = torch.randn(5, 7, 11) 579*da0073e9SAndroid Build Coastguard Worker result = vmap(vmap(op), (2, 0))(x, y) 580*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, op(x.permute(2, 0, 1), y)) 581*da0073e9SAndroid Build Coastguard Worker 582*da0073e9SAndroid Build Coastguard Worker # big batch size (total 10000) 583*da0073e9SAndroid Build Coastguard Worker x = torch.randn(100, 10, 10, 5) 584*da0073e9SAndroid Build Coastguard Worker y = torch.randn(100, 10, 10) 585*da0073e9SAndroid Build Coastguard Worker result = vmap(vmap(vmap(op)))(x, y) 586*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(result, op(x, y.view(100, 10, 10, 1))) 587*da0073e9SAndroid Build Coastguard Worker 588*da0073e9SAndroid Build Coastguard Worker def test_fallback_masked_fill(self): 589*da0073e9SAndroid Build Coastguard Worker # NB: One day we will implement a batching rule for masked_fill 590*da0073e9SAndroid Build Coastguard Worker # If/when we do, this test should be replaced to test the fallback 591*da0073e9SAndroid Build Coastguard Worker # path on another operator to avoid bitrot. 592*da0073e9SAndroid Build Coastguard Worker def run_test(batch_size): 593*da0073e9SAndroid Build Coastguard Worker B0 = batch_size 594*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, 7, 11, 13) 595*da0073e9SAndroid Build Coastguard Worker dim = 0 596*da0073e9SAndroid Build Coastguard Worker index = torch.tensor([0, 4, 2]) 597*da0073e9SAndroid Build Coastguard Worker values = torch.randn(B0, 3, 11, 13) 598*da0073e9SAndroid Build Coastguard Worker 599*da0073e9SAndroid Build Coastguard Worker self._assert_uses_vmap_fallback( 600*da0073e9SAndroid Build Coastguard Worker (torch.index_add, (0, None, None, 0)), (x, dim, index, values) 601*da0073e9SAndroid Build Coastguard Worker ) 602*da0073e9SAndroid Build Coastguard Worker 603*da0073e9SAndroid Build Coastguard Worker result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values) 604*da0073e9SAndroid Build Coastguard Worker expected = torch.index_add(x, dim + 1, index, values.view(B0, 3, 11, 13)) 605*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 606*da0073e9SAndroid Build Coastguard Worker 607*da0073e9SAndroid Build Coastguard Worker run_test(batch_size=5) 608*da0073e9SAndroid Build Coastguard Worker run_test(batch_size=1237) 609*da0073e9SAndroid Build Coastguard Worker 610*da0073e9SAndroid Build Coastguard Worker def test_fallback_multiple_returns(self): 611*da0073e9SAndroid Build Coastguard Worker # NB: One day we will implement a batching rule for torch.var_mean 
612*da0073e9SAndroid Build Coastguard Worker # If/when we do, this test should be replaced to test the fallback 613*da0073e9SAndroid Build Coastguard Worker # path on another operator to avoid bitrot. 614*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 2, 3, 1237 615*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(B0, 10) 616*da0073e9SAndroid Build Coastguard Worker 617*da0073e9SAndroid Build Coastguard Worker self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,)) 618*da0073e9SAndroid Build Coastguard Worker 619*da0073e9SAndroid Build Coastguard Worker # fallback correctness on torch.var_mean 620*da0073e9SAndroid Build Coastguard Worker result = vmap(torch.var_mean)(tensor) 621*da0073e9SAndroid Build Coastguard Worker expected = torch.var_mean(tensor, dim=1) 622*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 623*da0073e9SAndroid Build Coastguard Worker 624*da0073e9SAndroid Build Coastguard Worker # nested vmap 625*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(B0, B1, 10) 626*da0073e9SAndroid Build Coastguard Worker result = vmap(vmap(torch.var_mean))(tensor) 627*da0073e9SAndroid Build Coastguard Worker expected = torch.var_mean(tensor, dim=2) 628*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 629*da0073e9SAndroid Build Coastguard Worker 630*da0073e9SAndroid Build Coastguard Worker # big batch size, nested vmap 631*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(B0, B1, B2, 10) 632*da0073e9SAndroid Build Coastguard Worker result = vmap(vmap(vmap(torch.var_mean)))(tensor) 633*da0073e9SAndroid Build Coastguard Worker expected = torch.var_mean(tensor, dim=3) 634*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 635*da0073e9SAndroid Build Coastguard Worker 636*da0073e9SAndroid Build Coastguard Worker def test_inplace_fallback_unary(self): 637*da0073e9SAndroid Build Coastguard Worker # Test the in-place fallback on an in-place 
method that takes no 638*da0073e9SAndroid Build Coastguard Worker # additional Tensor arguments. This is the simplest case of the fallback. 639*da0073e9SAndroid Build Coastguard Worker # NB: One day we will implement a batching rule for acos_. 640*da0073e9SAndroid Build Coastguard Worker # If/when we do, this test should be replaced to test the fallback 641*da0073e9SAndroid Build Coastguard Worker # path on another operator to avoid bitrot. 642*da0073e9SAndroid Build Coastguard Worker op = Tensor.acos_ 643*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 2, 3, 10000 644*da0073e9SAndroid Build Coastguard Worker 645*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, 5) 646*da0073e9SAndroid Build Coastguard Worker self._assert_uses_vmap_fallback((op,), (x,)) 647*da0073e9SAndroid Build Coastguard Worker 648*da0073e9SAndroid Build Coastguard Worker # Single vmap 649*da0073e9SAndroid Build Coastguard Worker x_orig = torch.rand(B0, 5) 650*da0073e9SAndroid Build Coastguard Worker x = x_orig.clone() 651*da0073e9SAndroid Build Coastguard Worker result = vmap(op)(x) 652*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result is x) 653*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x_orig.acos()) 654*da0073e9SAndroid Build Coastguard Worker 655*da0073e9SAndroid Build Coastguard Worker # Single vmap + different out_dim produces a view(!) 
656*da0073e9SAndroid Build Coastguard Worker x_orig = torch.rand(B0, 5) 657*da0073e9SAndroid Build Coastguard Worker x = x_orig.clone() 658*da0073e9SAndroid Build Coastguard Worker result = vmap(op, out_dims=(1,))(x) 659*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result._base is x) 660*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x_orig.t().acos()) 661*da0073e9SAndroid Build Coastguard Worker 662*da0073e9SAndroid Build Coastguard Worker # Nested vmap 663*da0073e9SAndroid Build Coastguard Worker x_orig = torch.randn(B0, B1, 5) 664*da0073e9SAndroid Build Coastguard Worker x = x_orig.clone() 665*da0073e9SAndroid Build Coastguard Worker result = vmap(vmap(op))(x) 666*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result is x) 667*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x_orig.acos()) 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker # Nested vmap, large batch size 670*da0073e9SAndroid Build Coastguard Worker x_orig = torch.randn(B0, B1, B2, 5) 671*da0073e9SAndroid Build Coastguard Worker x = x_orig.clone() 672*da0073e9SAndroid Build Coastguard Worker result = vmap(vmap(vmap(op)))(x) 673*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result is x) 674*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x_orig.acos()) 675*da0073e9SAndroid Build Coastguard Worker 676*da0073e9SAndroid Build Coastguard Worker def test_inplace_fallback_nary_same_levels(self): 677*da0073e9SAndroid Build Coastguard Worker # NB: One day we will implement a batching rule for atan2_ 678*da0073e9SAndroid Build Coastguard Worker # If/when we do, this test should be replaced to test the fallback 679*da0073e9SAndroid Build Coastguard Worker # path on another operator to avoid bitrot. 
680*da0073e9SAndroid Build Coastguard Worker op = Tensor.atan2_ 681*da0073e9SAndroid Build Coastguard Worker outplace_op = torch.atan2 682*da0073e9SAndroid Build Coastguard Worker 683*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 7, 11) 684*da0073e9SAndroid Build Coastguard Worker y = torch.randn(5, 7, 11) 685*da0073e9SAndroid Build Coastguard Worker self._assert_uses_vmap_fallback((op,), (x, y)) 686*da0073e9SAndroid Build Coastguard Worker 687*da0073e9SAndroid Build Coastguard Worker # Single vmap 688*da0073e9SAndroid Build Coastguard Worker B0 = 5 689*da0073e9SAndroid Build Coastguard Worker x_orig = torch.randn(7, 11, B0) 690*da0073e9SAndroid Build Coastguard Worker x = x_orig.clone() 691*da0073e9SAndroid Build Coastguard Worker y = torch.randn(B0, 7, 11) 692*da0073e9SAndroid Build Coastguard Worker vmap(op, (2, 0))(x, y) 693*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2))) 694*da0073e9SAndroid Build Coastguard Worker 695*da0073e9SAndroid Build Coastguard Worker # Nested vmap 696*da0073e9SAndroid Build Coastguard Worker B0, B1 = 5, 7 697*da0073e9SAndroid Build Coastguard Worker x_orig = torch.randn(B1, 11, B0) 698*da0073e9SAndroid Build Coastguard Worker x = x_orig.clone() 699*da0073e9SAndroid Build Coastguard Worker y = torch.randn(B0, B1, 11) 700*da0073e9SAndroid Build Coastguard Worker vmap(vmap(op), (2, 0))(x, y) 701*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0]))) 702*da0073e9SAndroid Build Coastguard Worker 703*da0073e9SAndroid Build Coastguard Worker # big batch size (total 10000) 704*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 100, 10, 10 705*da0073e9SAndroid Build Coastguard Worker x_orig = torch.randn(B0, B1, B2, 5) 706*da0073e9SAndroid Build Coastguard Worker x = x_orig.clone() 707*da0073e9SAndroid Build Coastguard Worker y = torch.randn(B0, B1, B2) 708*da0073e9SAndroid Build Coastguard Worker result = 
vmap(vmap(vmap(op)))(x, y) 709*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1))) 710*da0073e9SAndroid Build Coastguard Worker 711*da0073e9SAndroid Build Coastguard Worker def test_inplace_fallback_nary_different_levels(self): 712*da0073e9SAndroid Build Coastguard Worker # NB: One day we will implement a batching rule for atan2_ 713*da0073e9SAndroid Build Coastguard Worker # If/when we do, this test should be replaced to test the fallback 714*da0073e9SAndroid Build Coastguard Worker # path on another operator to avoid bitrot. 715*da0073e9SAndroid Build Coastguard Worker op = Tensor.atan2_ 716*da0073e9SAndroid Build Coastguard Worker outplace_op = torch.atan2 717*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 2, 3, 5 718*da0073e9SAndroid Build Coastguard Worker 719*da0073e9SAndroid Build Coastguard Worker x = torch.rand(B0, 7) 720*da0073e9SAndroid Build Coastguard Worker y = torch.rand(7) 721*da0073e9SAndroid Build Coastguard Worker self._assert_uses_vmap_fallback((op, (0, None)), (x, y)) 722*da0073e9SAndroid Build Coastguard Worker 723*da0073e9SAndroid Build Coastguard Worker # op(left, right): All of the levels in right are found in left 724*da0073e9SAndroid Build Coastguard Worker x_orig = torch.rand(B0, 7) 725*da0073e9SAndroid Build Coastguard Worker x = x_orig.clone() 726*da0073e9SAndroid Build Coastguard Worker y = torch.rand(7) 727*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None))(x, y) 728*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, outplace_op(x_orig, y)) 729*da0073e9SAndroid Build Coastguard Worker 730*da0073e9SAndroid Build Coastguard Worker x_orig = torch.rand(B0, B1, 7) 731*da0073e9SAndroid Build Coastguard Worker x = x_orig.clone() 732*da0073e9SAndroid Build Coastguard Worker y = torch.rand(B0, 7) 733*da0073e9SAndroid Build Coastguard Worker vmap(vmap(op, in_dims=(0, None)))(x, y) 734*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, 
outplace_op(x_orig, y.view(B0, 1, 7))) 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard Worker # op(left, right): Some of the levels in right are not found in left 737*da0073e9SAndroid Build Coastguard Worker msg = r"vmap: aten::atan2_\(self, \*extra_args\) is not possible" 738*da0073e9SAndroid Build Coastguard Worker x = torch.rand(7) 739*da0073e9SAndroid Build Coastguard Worker y = torch.rand(B0, 7) 740*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 741*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(None, 0))(x, y) 742*da0073e9SAndroid Build Coastguard Worker 743*da0073e9SAndroid Build Coastguard Worker x = torch.rand(B1, 7) 744*da0073e9SAndroid Build Coastguard Worker y = torch.rand(B0, 7) 745*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 746*da0073e9SAndroid Build Coastguard Worker vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y) 747*da0073e9SAndroid Build Coastguard Worker 748*da0073e9SAndroid Build Coastguard Worker x = torch.rand(B1, 7) 749*da0073e9SAndroid Build Coastguard Worker y = torch.rand(7, B0) 750*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 751*da0073e9SAndroid Build Coastguard Worker vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y) 752*da0073e9SAndroid Build Coastguard Worker 753*da0073e9SAndroid Build Coastguard Worker x = torch.rand(B0, 7) 754*da0073e9SAndroid Build Coastguard Worker y = torch.rand(B0, B1, 7) 755*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 756*da0073e9SAndroid Build Coastguard Worker vmap(vmap(op, in_dims=(None, 0)))(x, y) 757*da0073e9SAndroid Build Coastguard Worker 758*da0073e9SAndroid Build Coastguard Worker def test_backward_unsupported_interaction(self): 759*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 760*da0073e9SAndroid Build Coastguard Worker y = 
torch.randn(5) 761*da0073e9SAndroid Build Coastguard Worker grad = torch.randn_like(x) 762*da0073e9SAndroid Build Coastguard Worker err_msg = r"backward\(\) called inside torch.vmap" 763*da0073e9SAndroid Build Coastguard Worker 764*da0073e9SAndroid Build Coastguard Worker def backward_on_vmapped_tensor(x): 765*da0073e9SAndroid Build Coastguard Worker x.sum().backward() 766*da0073e9SAndroid Build Coastguard Worker 767*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 768*da0073e9SAndroid Build Coastguard Worker vmap(backward_on_vmapped_tensor)(x) 769*da0073e9SAndroid Build Coastguard Worker 770*da0073e9SAndroid Build Coastguard Worker def backward_with_vmapped_grad(x, grad): 771*da0073e9SAndroid Build Coastguard Worker x.backward(grad) 772*da0073e9SAndroid Build Coastguard Worker 773*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 774*da0073e9SAndroid Build Coastguard Worker vmap(backward_with_vmapped_grad)(x, grad) 775*da0073e9SAndroid Build Coastguard Worker 776*da0073e9SAndroid Build Coastguard Worker def completely_unrelated_backward(y): 777*da0073e9SAndroid Build Coastguard Worker x.sum().backward() 778*da0073e9SAndroid Build Coastguard Worker 779*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 780*da0073e9SAndroid Build Coastguard Worker vmap(completely_unrelated_backward)(y) 781*da0073e9SAndroid Build Coastguard Worker 782*da0073e9SAndroid Build Coastguard Worker def test_grad_unsupported_interaction(self): 783*da0073e9SAndroid Build Coastguard Worker input_tensor = torch.randn(3, requires_grad=True) 784*da0073e9SAndroid Build Coastguard Worker err_msg = "autograd.grad.* called inside torch.vmap" 785*da0073e9SAndroid Build Coastguard Worker 786*da0073e9SAndroid Build Coastguard Worker captured = torch.randn(3, requires_grad=True) 787*da0073e9SAndroid Build Coastguard Worker 788*da0073e9SAndroid Build Coastguard Worker def 
output_to_grad_is_vmapped(input_tensor): 789*da0073e9SAndroid Build Coastguard Worker output = (captured * input_tensor).sum() 790*da0073e9SAndroid Build Coastguard Worker return torch.autograd.grad([output], [captured])[0] 791*da0073e9SAndroid Build Coastguard Worker 792*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 793*da0073e9SAndroid Build Coastguard Worker vmap(output_to_grad_is_vmapped)(input_tensor) 794*da0073e9SAndroid Build Coastguard Worker 795*da0073e9SAndroid Build Coastguard Worker output = (input_tensor**2).sum() 796*da0073e9SAndroid Build Coastguard Worker 797*da0073e9SAndroid Build Coastguard Worker def input_to_grad_is_vmapped(input_tensor): 798*da0073e9SAndroid Build Coastguard Worker return torch.autograd.grad([output], [input_tensor])[0] 799*da0073e9SAndroid Build Coastguard Worker 800*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 801*da0073e9SAndroid Build Coastguard Worker vmap(input_to_grad_is_vmapped)(input_tensor) 802*da0073e9SAndroid Build Coastguard Worker 803*da0073e9SAndroid Build Coastguard Worker def test_batched_gradient_basic(self): 804*da0073e9SAndroid Build Coastguard Worker N = 3 805*da0073e9SAndroid Build Coastguard Worker x = torch.randn(N, requires_grad=True) 806*da0073e9SAndroid Build Coastguard Worker y = torch.randn(N) 807*da0073e9SAndroid Build Coastguard Worker 808*da0073e9SAndroid Build Coastguard Worker def vjp_mul(v): 809*da0073e9SAndroid Build Coastguard Worker return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0] 810*da0073e9SAndroid Build Coastguard Worker 811*da0073e9SAndroid Build Coastguard Worker batched_v = torch.eye(N) 812*da0073e9SAndroid Build Coastguard Worker jacobian = vmap(vjp_mul)(batched_v) 813*da0073e9SAndroid Build Coastguard Worker self.assertEqual(jacobian, torch.diagflat(y)) 814*da0073e9SAndroid Build Coastguard Worker 815*da0073e9SAndroid Build Coastguard Worker def test_functools_partial(self): 
816*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 817*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 3) 818*da0073e9SAndroid Build Coastguard Worker result = vmap(functools.partial(torch.mul, x))(y) 819*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, x * y) 820*da0073e9SAndroid Build Coastguard Worker 821*da0073e9SAndroid Build Coastguard Worker def test_nn_module(self): 822*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(2, 3) 823*da0073e9SAndroid Build Coastguard Worker model = torch.nn.Linear(3, 3, bias=False) 824*da0073e9SAndroid Build Coastguard Worker result = vmap(model)(tensor) 825*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, model(tensor)) 826*da0073e9SAndroid Build Coastguard Worker 827*da0073e9SAndroid Build Coastguard Worker def test_fallback_with_undefined_grad(self): 828*da0073e9SAndroid Build Coastguard Worker B0 = 7 829*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, 4, 5, requires_grad=True) 830*da0073e9SAndroid Build Coastguard Worker weight = torch.randn(3, 3, 1, 1) 831*da0073e9SAndroid Build Coastguard Worker v = torch.randn(B0, 2, 3, 4, 5) 832*da0073e9SAndroid Build Coastguard Worker 833*da0073e9SAndroid Build Coastguard Worker def get_vjp(v): 834*da0073e9SAndroid Build Coastguard Worker result = torch.nn.functional.conv2d(x, weight) 835*da0073e9SAndroid Build Coastguard Worker (grad_x,) = torch.autograd.grad(result, x, v) 836*da0073e9SAndroid Build Coastguard Worker return grad_x 837*da0073e9SAndroid Build Coastguard Worker 838*da0073e9SAndroid Build Coastguard Worker # Runs vmap(get_vjp)(v), which should not error out. 839*da0073e9SAndroid Build Coastguard Worker # The backward formula for convolution returns an undefined 840*da0073e9SAndroid Build Coastguard Worker # Tensor for grad_bias because the original bias does not exist. 
841*da0073e9SAndroid Build Coastguard Worker # 842*da0073e9SAndroid Build Coastguard Worker # In the future we'll probably add a batching rule for convolution 843*da0073e9SAndroid Build Coastguard Worker # backward. When this happens, we should modify this test to use a 844*da0073e9SAndroid Build Coastguard Worker # different op (and/or create and use a dummy operator) to avoid bitrot. 845*da0073e9SAndroid Build Coastguard Worker self._assert_uses_vmap_fallback([get_vjp], [v]) 846*da0073e9SAndroid Build Coastguard Worker 847*da0073e9SAndroid Build Coastguard Worker 848*da0073e9SAndroid Build Coastguard Workerdef slice_inputs(inputs, bdims, i): 849*da0073e9SAndroid Build Coastguard Worker result = [] 850*da0073e9SAndroid Build Coastguard Worker for inp, bdim in zip(inputs, bdims): 851*da0073e9SAndroid Build Coastguard Worker if bdim is None: 852*da0073e9SAndroid Build Coastguard Worker result.append(inp) 853*da0073e9SAndroid Build Coastguard Worker else: 854*da0073e9SAndroid Build Coastguard Worker result.append(inp.select(bdim, i)) 855*da0073e9SAndroid Build Coastguard Worker return tuple(result) 856*da0073e9SAndroid Build Coastguard Worker 857*da0073e9SAndroid Build Coastguard Worker 858*da0073e9SAndroid Build Coastguard Workerdef reference_vmap(op, inputs, in_dims=0, out_dims=0): 859*da0073e9SAndroid Build Coastguard Worker if isinstance(in_dims, int): 860*da0073e9SAndroid Build Coastguard Worker in_dims = (in_dims,) * len(inputs) 861*da0073e9SAndroid Build Coastguard Worker bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None] 862*da0073e9SAndroid Build Coastguard Worker assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes) 863*da0073e9SAndroid Build Coastguard Worker bdim_size = bdim_sizes[0] 864*da0073e9SAndroid Build Coastguard Worker results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size)) 865*da0073e9SAndroid Build Coastguard Worker 866*da0073e9SAndroid Build Coastguard Worker assert 
len(results) > 0 867*da0073e9SAndroid Build Coastguard Worker op_has_single_return = not isinstance(results[0], tuple) 868*da0073e9SAndroid Build Coastguard Worker if op_has_single_return: 869*da0073e9SAndroid Build Coastguard Worker assert all(isinstance(result, torch.Tensor) for result in results) 870*da0073e9SAndroid Build Coastguard Worker if isinstance(out_dims, int): 871*da0073e9SAndroid Build Coastguard Worker out_dims = (out_dims,) * 1 872*da0073e9SAndroid Build Coastguard Worker return torch.stack(results, dim=out_dims[0]) 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Worker assert all(isinstance(result, tuple) for result in results) 875*da0073e9SAndroid Build Coastguard Worker num_returns = len(results[0]) 876*da0073e9SAndroid Build Coastguard Worker assert all(len(result) == num_returns for result in results) 877*da0073e9SAndroid Build Coastguard Worker if isinstance(out_dims, int): 878*da0073e9SAndroid Build Coastguard Worker out_dims = (out_dims,) * num_returns 879*da0073e9SAndroid Build Coastguard Worker return tuple( 880*da0073e9SAndroid Build Coastguard Worker torch.stack(result_shards, out_dim) 881*da0073e9SAndroid Build Coastguard Worker for result_shards, out_dim in zip(zip(*results), out_dims) 882*da0073e9SAndroid Build Coastguard Worker ) 883*da0073e9SAndroid Build Coastguard Worker 884*da0073e9SAndroid Build Coastguard Worker 885*da0073e9SAndroid Build Coastguard Workerclass TensorFactory: 886*da0073e9SAndroid Build Coastguard Worker @staticmethod 887*da0073e9SAndroid Build Coastguard Worker def rand(size, device="cpu", dtype=torch.float): 888*da0073e9SAndroid Build Coastguard Worker return torch.rand(size, device=device, dtype=dtype) 889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker @staticmethod 891*da0073e9SAndroid Build Coastguard Worker def randn(size, device="cpu", dtype=torch.float): 892*da0073e9SAndroid Build Coastguard Worker return torch.randn(size, 
device=device, dtype=dtype) 893*da0073e9SAndroid Build Coastguard Worker 894*da0073e9SAndroid Build Coastguard Worker @staticmethod 895*da0073e9SAndroid Build Coastguard Worker def randp1(size, device="cpu", dtype=torch.float): 896*da0073e9SAndroid Build Coastguard Worker return torch.rand(size, device=device, dtype=dtype) + 1 897*da0073e9SAndroid Build Coastguard Worker 898*da0073e9SAndroid Build Coastguard Worker 899*da0073e9SAndroid Build Coastguard Worker# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a 900*da0073e9SAndroid Build Coastguard Worker# (slow) sequential map+stack fallback. 901*da0073e9SAndroid Build Coastguard Worker# 902*da0073e9SAndroid Build Coastguard Worker# check_view: Test if the first returned output is a view of the first input 903*da0073e9SAndroid Build Coastguard Worker# check_propagates_grad: Test if the operation propagates gradients. 904*da0073e9SAndroid Build Coastguard Workerdef _vmap_test( 905*da0073e9SAndroid Build Coastguard Worker self, 906*da0073e9SAndroid Build Coastguard Worker op, 907*da0073e9SAndroid Build Coastguard Worker inputs, 908*da0073e9SAndroid Build Coastguard Worker in_dims=0, 909*da0073e9SAndroid Build Coastguard Worker out_dims=0, 910*da0073e9SAndroid Build Coastguard Worker check_view=False, 911*da0073e9SAndroid Build Coastguard Worker check_propagates_grad=True, 912*da0073e9SAndroid Build Coastguard Worker): 913*da0073e9SAndroid Build Coastguard Worker result = vmap(op, in_dims, out_dims)(*inputs) 914*da0073e9SAndroid Build Coastguard Worker reference_result = reference_vmap(op, inputs, in_dims, out_dims) 915*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, reference_result) 916*da0073e9SAndroid Build Coastguard Worker op_has_single_return = not isinstance(result, tuple) 917*da0073e9SAndroid Build Coastguard Worker 918*da0073e9SAndroid Build Coastguard Worker if check_view: 919*da0073e9SAndroid Build Coastguard Worker result_as_tuple = (result,) if op_has_single_return 
else result 920*da0073e9SAndroid Build Coastguard Worker for output in result_as_tuple: 921*da0073e9SAndroid Build Coastguard Worker input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base 922*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 923*da0073e9SAndroid Build Coastguard Worker output._base is input0_base, 924*da0073e9SAndroid Build Coastguard Worker msg="result was not a view of the first input!", 925*da0073e9SAndroid Build Coastguard Worker ) 926*da0073e9SAndroid Build Coastguard Worker 927*da0073e9SAndroid Build Coastguard Worker if not check_propagates_grad: 928*da0073e9SAndroid Build Coastguard Worker return 929*da0073e9SAndroid Build Coastguard Worker # Assuming input[0] is a floating-point tensor. Check if the vmap 930*da0073e9SAndroid Build Coastguard Worker # operation propagates the requires_grad flag to the zeroth output. 931*da0073e9SAndroid Build Coastguard Worker # Some vmap operators are implemented in a way that assumes that 932*da0073e9SAndroid Build Coastguard Worker # they are composite with respect to autograd. If the operator ever is 933*da0073e9SAndroid Build Coastguard Worker # changed to not be composite with respect to autograd, then the 934*da0073e9SAndroid Build Coastguard Worker # following check should fail. 
935*da0073e9SAndroid Build Coastguard Worker inputs_clone = list(inputs) 936*da0073e9SAndroid Build Coastguard Worker inputs_clone[0] = inputs[0].clone().requires_grad_() 937*da0073e9SAndroid Build Coastguard Worker result = vmap(op, in_dims, out_dims)(*inputs_clone) 938*da0073e9SAndroid Build Coastguard Worker result_as_tuple = (result,) if op_has_single_return else result 939*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result[0].requires_grad) 940*da0073e9SAndroid Build Coastguard Worker 941*da0073e9SAndroid Build Coastguard Worker 942*da0073e9SAndroid Build Coastguard Workerdef should_allow_vmap_fallback_usage(fn): 943*da0073e9SAndroid Build Coastguard Worker return getattr(fn, "_allow_vmap_fallback_usage", False) 944*da0073e9SAndroid Build Coastguard Worker 945*da0073e9SAndroid Build Coastguard Worker 946*da0073e9SAndroid Build Coastguard Workerdef allowVmapFallbackUsage(fn): 947*da0073e9SAndroid Build Coastguard Worker fn._allow_vmap_fallback_usage = True 948*da0073e9SAndroid Build Coastguard Worker return fn 949*da0073e9SAndroid Build Coastguard Worker 950*da0073e9SAndroid Build Coastguard Worker 951*da0073e9SAndroid Build Coastguard Worker# All tests of TestVmapBaseLegacy check that the slow vmap fallback is never invoked. 952*da0073e9SAndroid Build Coastguard Worker# This is so that we can incrementally add batching rules for operators to 953*da0073e9SAndroid Build Coastguard Worker# replace the slow vmap fallback path for said operators. To skip this check, 954*da0073e9SAndroid Build Coastguard Worker# please use the allowVmapFallbackUsage decorator. 955*da0073e9SAndroid Build Coastguard Worker# 956*da0073e9SAndroid Build Coastguard Worker# NB: Don't add tests to TestVmapBaseLegacy directly, unless you want them to run 957*da0073e9SAndroid Build Coastguard Worker# on every subclass of TestVmapBaseLegacy. Add them to e.g. TestVmapOperators. 
#
# NB: TestVmapBaseLegacy is a nested class. This prevents test runners from picking
# it up and running it.
class Namespace:
    class TestVmapBaseLegacy(TestCase):
        def __init__(self, method_name="runTest"):
            super().__init__(method_name)

            fn = getattr(self, method_name, None)
            if fn is None:
                return

            # Tests opt out of the fallback check with @allowVmapFallbackUsage;
            # everyone else gets wrapped so any fallback warning fails the test.
            if should_allow_vmap_fallback_usage(fn):
                return
            setattr(self, method_name, self._wrap_method_with_vmap_fallback_check(fn))

        def _wrap_method_with_vmap_fallback_check(self, method):
            msg = (
                "Expected the test to not invoke the vmap fallback path, i.e., "
                "all of the operators being tested in this test should have batching "
                "rules implemented. If you are intentionally testing something to "
                "do with the fallback path, use allowVmapFallbackUsage. Otherwise, "
                "please make sure that batching rules are implemented for the "
                "operator(s) being tested."
            )

            @functools.wraps(method)
            def wrapper(self, *args, **kwargs):
                # Record every warning emitted while the test runs, then make
                # sure none of them is the vmap performance-fallback warning.
                with warnings.catch_warnings(record=True) as recorded:
                    warnings.simplefilter("always")
                    with EnableVmapFallbackWarnings():
                        method(*args, **kwargs)
                    for w in recorded:
                        self.assertNotRegex(str(w.message), FALLBACK_REGEX, msg)

            return types.MethodType(wrapper, self)

        @allowVmapFallbackUsage
        def test_vmap_fallback_check_ok(self):
            # One day we'll implement a batching rule for torch.var_mean.
            # When that happens, please change the example to use an
            # operator that doesn't have a batching rule implemented.
            op_using_fallback = torch.var_mean
            vmap(op_using_fallback)(torch.rand(3))

        def test_vmap_fallback_check(self):
            @self._wrap_method_with_vmap_fallback_check
            def no_fallback(self):
                pass

            # One day we'll implement a batching rule for torch.var_mean.
            # When that happens, please change the example to use an
            # operator that doesn't have a batching rule implemented.
            op_using_fallback = torch.var_mean

            @self._wrap_method_with_vmap_fallback_check
            def uses_fallback(self):
                vmap(op_using_fallback)(torch.rand(3))

            # A fallback-free test passes through the wrapper untouched...
            no_fallback(self)

            # ...while one that hits the fallback trips the wrapped assertion.
            with self.assertRaises(AssertionError):
                uses_fallback(self)


class TestVmapOperatorsLegacy(Namespace.TestVmapBaseLegacy):
    def _vmap_test(self, *args, **kwargs):
        # Delegate to the module-level helper, passing this TestCase along.
        return _vmap_test(self, *args, **kwargs)
Coastguard Worker def _vmap_view_test(self, *args, **kwargs): 1033*da0073e9SAndroid Build Coastguard Worker self._vmap_test(*args, **kwargs, check_view=True) 1034*da0073e9SAndroid Build Coastguard Worker 1035*da0073e9SAndroid Build Coastguard Worker def _test_unary(self, op, getter, device, *args, **kwargs): 1036*da0073e9SAndroid Build Coastguard Worker test = functools.partial(self._vmap_test, *args, **kwargs) 1037*da0073e9SAndroid Build Coastguard Worker B0, B1 = 7, 11 1038*da0073e9SAndroid Build Coastguard Worker 1039*da0073e9SAndroid Build Coastguard Worker # Single vmap, various in_dims / out_dims 1040*da0073e9SAndroid Build Coastguard Worker test(op, [getter([B0, 3], device)]) 1041*da0073e9SAndroid Build Coastguard Worker test(op, [getter([2, 5, B0, 3], device)], in_dims=2) 1042*da0073e9SAndroid Build Coastguard Worker test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2) 1043*da0073e9SAndroid Build Coastguard Worker 1044*da0073e9SAndroid Build Coastguard Worker # Doubly nested vmap 1045*da0073e9SAndroid Build Coastguard Worker test(vmap(op), [getter([B0, B1], device)]) 1046*da0073e9SAndroid Build Coastguard Worker test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2) 1047*da0073e9SAndroid Build Coastguard Worker test( 1048*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=2), 1049*da0073e9SAndroid Build Coastguard Worker [getter([2, 5, B0, B1, 3], device)], 1050*da0073e9SAndroid Build Coastguard Worker in_dims=2, 1051*da0073e9SAndroid Build Coastguard Worker out_dims=2, 1052*da0073e9SAndroid Build Coastguard Worker ) 1053*da0073e9SAndroid Build Coastguard Worker 1054*da0073e9SAndroid Build Coastguard Worker def test_unary_pointwise_ops(self): 1055*da0073e9SAndroid Build Coastguard Worker cases = [ 1056*da0073e9SAndroid Build Coastguard Worker (torch.abs, TensorFactory.randn), 1057*da0073e9SAndroid Build Coastguard Worker (torch.acos, TensorFactory.rand), 1058*da0073e9SAndroid Build Coastguard Worker (torch.asin, 
TensorFactory.rand), 1059*da0073e9SAndroid Build Coastguard Worker (torch.atan, TensorFactory.rand), 1060*da0073e9SAndroid Build Coastguard Worker (torch.ceil, TensorFactory.randn), 1061*da0073e9SAndroid Build Coastguard Worker (torch.cos, TensorFactory.rand), 1062*da0073e9SAndroid Build Coastguard Worker (torch.cosh, TensorFactory.rand), 1063*da0073e9SAndroid Build Coastguard Worker (torch.digamma, TensorFactory.rand), 1064*da0073e9SAndroid Build Coastguard Worker (torch.exp, TensorFactory.randn), 1065*da0073e9SAndroid Build Coastguard Worker (torch.expm1, TensorFactory.randn), 1066*da0073e9SAndroid Build Coastguard Worker (torch.floor, TensorFactory.randn), 1067*da0073e9SAndroid Build Coastguard Worker (torch.frac, TensorFactory.randn), 1068*da0073e9SAndroid Build Coastguard Worker (torch.lgamma, TensorFactory.rand), 1069*da0073e9SAndroid Build Coastguard Worker (torch.log, TensorFactory.randp1), 1070*da0073e9SAndroid Build Coastguard Worker (torch.log10, TensorFactory.randp1), 1071*da0073e9SAndroid Build Coastguard Worker (torch.log1p, TensorFactory.randp1), 1072*da0073e9SAndroid Build Coastguard Worker (torch.log2, TensorFactory.randp1), 1073*da0073e9SAndroid Build Coastguard Worker (torch.neg, TensorFactory.randn), 1074*da0073e9SAndroid Build Coastguard Worker (torch.reciprocal, TensorFactory.randp1), 1075*da0073e9SAndroid Build Coastguard Worker (torch.relu, TensorFactory.randn), 1076*da0073e9SAndroid Build Coastguard Worker (torch.round, TensorFactory.randn), 1077*da0073e9SAndroid Build Coastguard Worker (torch.rsqrt, TensorFactory.randp1), 1078*da0073e9SAndroid Build Coastguard Worker (torch.sigmoid, TensorFactory.randn), 1079*da0073e9SAndroid Build Coastguard Worker (torch.sign, TensorFactory.randn), 1080*da0073e9SAndroid Build Coastguard Worker (torch.sin, TensorFactory.rand), 1081*da0073e9SAndroid Build Coastguard Worker (torch.sinh, TensorFactory.rand), 1082*da0073e9SAndroid Build Coastguard Worker (torch.sqrt, TensorFactory.rand), 1083*da0073e9SAndroid 
Build Coastguard Worker (torch.tan, TensorFactory.rand), 1084*da0073e9SAndroid Build Coastguard Worker (torch.tanh, TensorFactory.rand), 1085*da0073e9SAndroid Build Coastguard Worker (torch.trunc, TensorFactory.randn), 1086*da0073e9SAndroid Build Coastguard Worker ] 1087*da0073e9SAndroid Build Coastguard Worker for op, getter in cases: 1088*da0073e9SAndroid Build Coastguard Worker self._test_unary(op, getter, "cpu") 1089*da0073e9SAndroid Build Coastguard Worker 1090*da0073e9SAndroid Build Coastguard Worker def test_clone(self): 1091*da0073e9SAndroid Build Coastguard Worker # Some basic tests 1092*da0073e9SAndroid Build Coastguard Worker self._test_unary(lambda x: x.clone(), TensorFactory.randn, "cpu") 1093*da0073e9SAndroid Build Coastguard Worker self._test_unary( 1094*da0073e9SAndroid Build Coastguard Worker lambda x: x.clone(memory_format=torch.preserve_format), 1095*da0073e9SAndroid Build Coastguard Worker TensorFactory.randn, 1096*da0073e9SAndroid Build Coastguard Worker "cpu", 1097*da0073e9SAndroid Build Coastguard Worker ) 1098*da0073e9SAndroid Build Coastguard Worker self._test_unary( 1099*da0073e9SAndroid Build Coastguard Worker lambda x: x.clone(memory_format=torch.contiguous_format), 1100*da0073e9SAndroid Build Coastguard Worker TensorFactory.randn, 1101*da0073e9SAndroid Build Coastguard Worker "cpu", 1102*da0073e9SAndroid Build Coastguard Worker ) 1103*da0073e9SAndroid Build Coastguard Worker 1104*da0073e9SAndroid Build Coastguard Worker # Test that the per-examples are contiguous when using torch.contiguous_format 1105*da0073e9SAndroid Build Coastguard Worker def clone_contiguous(x): 1106*da0073e9SAndroid Build Coastguard Worker return x.clone(memory_format=torch.contiguous_format) 1107*da0073e9SAndroid Build Coastguard Worker 1108*da0073e9SAndroid Build Coastguard Worker B0, B1 = 3, 5 1109*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, B0, 7) 1110*da0073e9SAndroid Build Coastguard Worker y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x) 
1111*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.movedim(1, 0).is_contiguous()) 1112*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y[:, 0, :].is_contiguous()) 1113*da0073e9SAndroid Build Coastguard Worker 1114*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, B0, 7, B1) 1115*da0073e9SAndroid Build Coastguard Worker y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x) 1116*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.is_contiguous()) 1117*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y[0][0].is_contiguous()) 1118*da0073e9SAndroid Build Coastguard Worker 1119*da0073e9SAndroid Build Coastguard Worker msg = r"only supported with memory_format torch.preserve_format or torch.contiguous_format" 1120*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1121*da0073e9SAndroid Build Coastguard Worker vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0)) 1122*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1123*da0073e9SAndroid Build Coastguard Worker vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))( 1124*da0073e9SAndroid Build Coastguard Worker torch.randn(B0) 1125*da0073e9SAndroid Build Coastguard Worker ) 1126*da0073e9SAndroid Build Coastguard Worker 1127*da0073e9SAndroid Build Coastguard Worker def test_binary_pointwise_ops(self): 1128*da0073e9SAndroid Build Coastguard Worker def get_number(getter): 1129*da0073e9SAndroid Build Coastguard Worker return getter([]).item() 1130*da0073e9SAndroid Build Coastguard Worker 1131*da0073e9SAndroid Build Coastguard Worker def make_case(op, input_getter=TensorFactory.randn): 1132*da0073e9SAndroid Build Coastguard Worker return (op, input_getter) 1133*da0073e9SAndroid Build Coastguard Worker 1134*da0073e9SAndroid Build Coastguard Worker cases = [ 1135*da0073e9SAndroid Build Coastguard Worker # Basic arithmetic 1136*da0073e9SAndroid Build Coastguard Worker 
make_case(torch.add), 1137*da0073e9SAndroid Build Coastguard Worker make_case(lambda x, y: x + y), 1138*da0073e9SAndroid Build Coastguard Worker make_case(torch.sub), 1139*da0073e9SAndroid Build Coastguard Worker make_case(lambda x, y: x - y), 1140*da0073e9SAndroid Build Coastguard Worker make_case(torch.mul), 1141*da0073e9SAndroid Build Coastguard Worker make_case(lambda x, y: x * y), 1142*da0073e9SAndroid Build Coastguard Worker make_case(torch.div, input_getter=TensorFactory.randp1), 1143*da0073e9SAndroid Build Coastguard Worker make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1), 1144*da0073e9SAndroid Build Coastguard Worker make_case(torch.pow, input_getter=TensorFactory.randp1), 1145*da0073e9SAndroid Build Coastguard Worker make_case(lambda x, y: x**y, input_getter=TensorFactory.randp1), 1146*da0073e9SAndroid Build Coastguard Worker ] 1147*da0073e9SAndroid Build Coastguard Worker test = self._vmap_test 1148*da0073e9SAndroid Build Coastguard Worker 1149*da0073e9SAndroid Build Coastguard Worker for op, getter in cases: 1150*da0073e9SAndroid Build Coastguard Worker device = "cpu" 1151*da0073e9SAndroid Build Coastguard Worker B0, B1 = 7, 11 1152*da0073e9SAndroid Build Coastguard Worker 1153*da0073e9SAndroid Build Coastguard Worker # Single vmap: op(Tensor, Tensor) 1154*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0, 3], device), getter([B0, 3], device))) 1155*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0], device), getter([B0, 2, 3], device))) 1156*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1)) 1157*da0073e9SAndroid Build Coastguard Worker test( 1158*da0073e9SAndroid Build Coastguard Worker op, 1159*da0073e9SAndroid Build Coastguard Worker (getter([B0], device), getter([2, B0, 3], device)), 1160*da0073e9SAndroid Build Coastguard Worker in_dims=(0, 1), 1161*da0073e9SAndroid Build Coastguard Worker out_dims=1, 1162*da0073e9SAndroid Build Coastguard 
Worker ) 1163*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None)) 1164*da0073e9SAndroid Build Coastguard Worker test( 1165*da0073e9SAndroid Build Coastguard Worker op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None) 1166*da0073e9SAndroid Build Coastguard Worker ) 1167*da0073e9SAndroid Build Coastguard Worker 1168*da0073e9SAndroid Build Coastguard Worker # Nested vmap: op(Tensor, Tensor) 1169*da0073e9SAndroid Build Coastguard Worker test( 1170*da0073e9SAndroid Build Coastguard Worker vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)) 1171*da0073e9SAndroid Build Coastguard Worker ) 1172*da0073e9SAndroid Build Coastguard Worker test( 1173*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(None, 0)), 1174*da0073e9SAndroid Build Coastguard Worker (getter([B0, 2, 3], device), getter([B1, 3], device)), 1175*da0073e9SAndroid Build Coastguard Worker in_dims=(0, None), 1176*da0073e9SAndroid Build Coastguard Worker ) 1177*da0073e9SAndroid Build Coastguard Worker 1178*da0073e9SAndroid Build Coastguard Worker # Python number overload: op(Tensor, Number) (and vice-versa) 1179*da0073e9SAndroid Build Coastguard Worker number = get_number(getter) 1180*da0073e9SAndroid Build Coastguard Worker self._test_unary(lambda t: op(t, number), getter, device) 1181*da0073e9SAndroid Build Coastguard Worker number = get_number(getter) 1182*da0073e9SAndroid Build Coastguard Worker self._test_unary(lambda t: op(number, t), getter, device) 1183*da0073e9SAndroid Build Coastguard Worker 1184*da0073e9SAndroid Build Coastguard Worker # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor) 1185*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0], device), getter([B0], device, dtype=torch.double))) 1186*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0], device, dtype=torch.double), getter([B0], device))) 1187*da0073e9SAndroid Build Coastguard Worker 
test(op, (getter([B0], device), getter([B0], device))) 1188*da0073e9SAndroid Build Coastguard Worker 1189*da0073e9SAndroid Build Coastguard Worker # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa) 1190*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0, 2], device), getter([B0], device, torch.double))) 1191*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0], device, torch.double), getter([B0, 2], device))) 1192*da0073e9SAndroid Build Coastguard Worker 1193*da0073e9SAndroid Build Coastguard Worker if not torch.cuda.is_available(): 1194*da0073e9SAndroid Build Coastguard Worker continue 1195*da0073e9SAndroid Build Coastguard Worker 1196*da0073e9SAndroid Build Coastguard Worker # TODO(rzou): fix the following 1197*da0073e9SAndroid Build Coastguard Worker # # Test cross-device scalars 1198*da0073e9SAndroid Build Coastguard Worker # number = get_number(getter) 1199*da0073e9SAndroid Build Coastguard Worker # self._test_unary(lambda t: op(t, number), getter, device='cuda') 1200*da0073e9SAndroid Build Coastguard Worker # self._test_unary(lambda t: op(number, t), getter, device='cuda') 1201*da0073e9SAndroid Build Coastguard Worker # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda') 1202*da0073e9SAndroid Build Coastguard Worker 1203*da0073e9SAndroid Build Coastguard Worker def test_as_strided(self): 1204*da0073e9SAndroid Build Coastguard Worker def _test(sizes, strides, offset, tensor, lambd): 1205*da0073e9SAndroid Build Coastguard Worker result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor) 1206*da0073e9SAndroid Build Coastguard Worker expected = vmap(lambd)(tensor) 1207*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result._base is expected._base) 1208*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 1209*da0073e9SAndroid Build Coastguard Worker 1210*da0073e9SAndroid Build Coastguard Worker # single vmap test 1211*da0073e9SAndroid Build Coastguard 
Worker B0 = 5 1212*da0073e9SAndroid Build Coastguard Worker tensors = [ 1213*da0073e9SAndroid Build Coastguard Worker # contiguous 1214*da0073e9SAndroid Build Coastguard Worker torch.randn(B0, 2, 3), 1215*da0073e9SAndroid Build Coastguard Worker # non-contiguous 1216*da0073e9SAndroid Build Coastguard Worker torch.randn(B0, 3, 2).transpose(1, 2), 1217*da0073e9SAndroid Build Coastguard Worker # non-zero storage offset 1218*da0073e9SAndroid Build Coastguard Worker torch.randn(2, B0, 2, 3)[1], 1219*da0073e9SAndroid Build Coastguard Worker # non-contiguous strides, zero storage offset 1220*da0073e9SAndroid Build Coastguard Worker torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0], 1221*da0073e9SAndroid Build Coastguard Worker # non-contiguous strides, non-zero storage offset 1222*da0073e9SAndroid Build Coastguard Worker torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1], 1223*da0073e9SAndroid Build Coastguard Worker ] 1224*da0073e9SAndroid Build Coastguard Worker 1225*da0073e9SAndroid Build Coastguard Worker for x in tensors: 1226*da0073e9SAndroid Build Coastguard Worker S0, S1 = x.stride()[1:] 1227*da0073e9SAndroid Build Coastguard Worker offset = x.storage_offset() 1228*da0073e9SAndroid Build Coastguard Worker 1229*da0073e9SAndroid Build Coastguard Worker # Broadcast 1230*da0073e9SAndroid Build Coastguard Worker _test( 1231*da0073e9SAndroid Build Coastguard Worker [5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3) 1232*da0073e9SAndroid Build Coastguard Worker ) 1233*da0073e9SAndroid Build Coastguard Worker # transpose 1234*da0073e9SAndroid Build Coastguard Worker _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1)) 1235*da0073e9SAndroid Build Coastguard Worker # select 1236*da0073e9SAndroid Build Coastguard Worker _test([2], [S0], offset + S1, x, lambda x: x[:, 1]) 1237*da0073e9SAndroid Build Coastguard Worker 1238*da0073e9SAndroid Build Coastguard Worker # Nested vmap test 1239*da0073e9SAndroid Build Coastguard Worker B1 = 7 1240*da0073e9SAndroid 
Build Coastguard Worker x = torch.randn(B1, B0, 2, 3) 1241*da0073e9SAndroid Build Coastguard Worker S0, S1 = x.stride()[2:] 1242*da0073e9SAndroid Build Coastguard Worker result = vmap( 1243*da0073e9SAndroid Build Coastguard Worker vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1 1244*da0073e9SAndroid Build Coastguard Worker )(x) 1245*da0073e9SAndroid Build Coastguard Worker expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x) 1246*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result._base is expected._base) 1247*da0073e9SAndroid Build Coastguard Worker self.assertEqual(result, expected) 1248*da0073e9SAndroid Build Coastguard Worker 1249*da0073e9SAndroid Build Coastguard Worker # Check that mal-formatted size/strides doesn't crash 1250*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1251*da0073e9SAndroid Build Coastguard Worker RuntimeError, "size and stride must have the same length" 1252*da0073e9SAndroid Build Coastguard Worker ): 1253*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, 2, 3).transpose(0, 1) 1254*da0073e9SAndroid Build Coastguard Worker vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x) 1255*da0073e9SAndroid Build Coastguard Worker 1256*da0073e9SAndroid Build Coastguard Worker # Sanity check #1: we require the batch dims to be at the front of the 1257*da0073e9SAndroid Build Coastguard Worker # tensor (in memory layout). 
1258*da0073e9SAndroid Build Coastguard Worker msg = "batch dims being vmapped over are at the front of the tensor" 1259*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1260*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, B0, 3).transpose(0, 1) 1261*da0073e9SAndroid Build Coastguard Worker vmap(lambda x: x.as_strided([2, 3], [B0 * 3, 1]))(x) 1262*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1263*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, 2, 3, B1).movedim(3, 1) 1264*da0073e9SAndroid Build Coastguard Worker vmap(vmap(lambda x: x.as_strided([2, 3], [B1 * 3, B1])))(x) 1265*da0073e9SAndroid Build Coastguard Worker 1266*da0073e9SAndroid Build Coastguard Worker # All the Sanity check #2{a,b,c} cases check that 1267*da0073e9SAndroid Build Coastguard Worker # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1268*da0073e9SAndroid Build Coastguard Worker # doesn't index memory that is out of bounds of xs[i]. This condition 1269*da0073e9SAndroid Build Coastguard Worker # is important to the correctness of the as_strided batching rule 1270*da0073e9SAndroid Build Coastguard Worker # (see NOTE: [When will the as_strided_batching_rule fail?]) 1271*da0073e9SAndroid Build Coastguard Worker 1272*da0073e9SAndroid Build Coastguard Worker # Sanity check #2a: The maximum indexable location of 1273*da0073e9SAndroid Build Coastguard Worker # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1274*da0073e9SAndroid Build Coastguard Worker # is less than or equal to the maximum indexable location of xs[i]. 
1275*da0073e9SAndroid Build Coastguard Worker msg = "This is not supported inside of vmap" 1276*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1277*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, 3) 1278*da0073e9SAndroid Build Coastguard Worker vmap(lambda x: x.as_strided([3], [1], 1))(x) 1279*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1280*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, 3, 5) 1281*da0073e9SAndroid Build Coastguard Worker vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x) 1282*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1283*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, B1, 3, 5) 1284*da0073e9SAndroid Build Coastguard Worker vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x) 1285*da0073e9SAndroid Build Coastguard Worker 1286*da0073e9SAndroid Build Coastguard Worker # Sanity check #2b: The min indexable location of 1287*da0073e9SAndroid Build Coastguard Worker # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1288*da0073e9SAndroid Build Coastguard Worker # is greater than or equal to the min indexable location of xs[i]. 
1289*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1290*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, B0, 3)[1] 1291*da0073e9SAndroid Build Coastguard Worker vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x) 1292*da0073e9SAndroid Build Coastguard Worker 1293*da0073e9SAndroid Build Coastguard Worker # Sanity check #2c: 1294*da0073e9SAndroid Build Coastguard Worker # xs[i] is a zero-dim tensor, but 1295*da0073e9SAndroid Build Coastguard Worker # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) 1296*da0073e9SAndroid Build Coastguard Worker # is not 1297*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1298*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, 0, 3) 1299*da0073e9SAndroid Build Coastguard Worker vmap(lambda x: x.as_strided([3], [1]))(x) 1300*da0073e9SAndroid Build Coastguard Worker 1301*da0073e9SAndroid Build Coastguard Worker def test_bmm(self): 1302*da0073e9SAndroid Build Coastguard Worker op = torch.bmm 1303*da0073e9SAndroid Build Coastguard Worker test = self._vmap_test 1304*da0073e9SAndroid Build Coastguard Worker B0, B1 = 7, 11 1305*da0073e9SAndroid Build Coastguard Worker 1306*da0073e9SAndroid Build Coastguard Worker # shape mismatch 1307*da0073e9SAndroid Build Coastguard Worker msg = "Shape mismatch" 1308*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1309*da0073e9SAndroid Build Coastguard Worker vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) 1310*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1311*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2)) 1312*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1313*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2)) 
1314*da0073e9SAndroid Build Coastguard Worker 1315*da0073e9SAndroid Build Coastguard Worker # left arg is vmapped 1316*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None)) 1317*da0073e9SAndroid Build Coastguard Worker test( 1318*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None)), 1319*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)), 1320*da0073e9SAndroid Build Coastguard Worker in_dims=(1, None), 1321*da0073e9SAndroid Build Coastguard Worker ) 1322*da0073e9SAndroid Build Coastguard Worker 1323*da0073e9SAndroid Build Coastguard Worker # right arg is vmapped 1324*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0)) 1325*da0073e9SAndroid Build Coastguard Worker test( 1326*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(None, 0)), 1327*da0073e9SAndroid Build Coastguard Worker (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)), 1328*da0073e9SAndroid Build Coastguard Worker in_dims=(None, 1), 1329*da0073e9SAndroid Build Coastguard Worker ) 1330*da0073e9SAndroid Build Coastguard Worker 1331*da0073e9SAndroid Build Coastguard Worker # both args are vmapped 1332*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3))) 1333*da0073e9SAndroid Build Coastguard Worker test( 1334*da0073e9SAndroid Build Coastguard Worker vmap(op), 1335*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)), 1336*da0073e9SAndroid Build Coastguard Worker in_dims=(1, 0), 1337*da0073e9SAndroid Build Coastguard Worker ) 1338*da0073e9SAndroid Build Coastguard Worker test( 1339*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None)), 1340*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), 1341*da0073e9SAndroid Build Coastguard Worker in_dims=(None, 0), 
1342*da0073e9SAndroid Build Coastguard Worker ) 1343*da0073e9SAndroid Build Coastguard Worker 1344*da0073e9SAndroid Build Coastguard Worker def test_cat(self): 1345*da0073e9SAndroid Build Coastguard Worker test = self._vmap_test 1346*da0073e9SAndroid Build Coastguard Worker B0, B1 = 5, 7 1347*da0073e9SAndroid Build Coastguard Worker 1348*da0073e9SAndroid Build Coastguard Worker # Quick hack b/c vmap can't accept a list of tensors as an argument 1349*da0073e9SAndroid Build Coastguard Worker def get_op(dim): 1350*da0073e9SAndroid Build Coastguard Worker def op(*tensors): 1351*da0073e9SAndroid Build Coastguard Worker return torch.cat(tensors, dim=dim) 1352*da0073e9SAndroid Build Coastguard Worker 1353*da0073e9SAndroid Build Coastguard Worker return op 1354*da0073e9SAndroid Build Coastguard Worker 1355*da0073e9SAndroid Build Coastguard Worker test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3))) 1356*da0073e9SAndroid Build Coastguard Worker test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0)) 1357*da0073e9SAndroid Build Coastguard Worker test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2)) 1358*da0073e9SAndroid Build Coastguard Worker test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2)) 1359*da0073e9SAndroid Build Coastguard Worker test( 1360*da0073e9SAndroid Build Coastguard Worker vmap(get_op(0), in_dims=(0, None)), 1361*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2), torch.rand(B0, 3)), 1362*da0073e9SAndroid Build Coastguard Worker in_dims=(None, 0), 1363*da0073e9SAndroid Build Coastguard Worker ) 1364*da0073e9SAndroid Build Coastguard Worker test( 1365*da0073e9SAndroid Build Coastguard Worker vmap(get_op(0), in_dims=(0, 0)), 1366*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2), torch.rand(B0, B1, 3)), 1367*da0073e9SAndroid Build Coastguard Worker in_dims=(None, 0), 1368*da0073e9SAndroid Build Coastguard Worker ) 1369*da0073e9SAndroid Build Coastguard Worker 
1370*da0073e9SAndroid Build Coastguard Worker def test_conj(self): 1371*da0073e9SAndroid Build Coastguard Worker op = torch.conj 1372*da0073e9SAndroid Build Coastguard Worker 1373*da0073e9SAndroid Build Coastguard Worker def run_test(dtype): 1374*da0073e9SAndroid Build Coastguard Worker def get(shape): 1375*da0073e9SAndroid Build Coastguard Worker return torch.randn(shape, dtype=dtype) 1376*da0073e9SAndroid Build Coastguard Worker 1377*da0073e9SAndroid Build Coastguard Worker B0, B1 = 7, 11 1378*da0073e9SAndroid Build Coastguard Worker test = self._vmap_test 1379*da0073e9SAndroid Build Coastguard Worker 1380*da0073e9SAndroid Build Coastguard Worker # Single vmap, various in_dims / out_dims 1381*da0073e9SAndroid Build Coastguard Worker test(op, [get([B0, 3])]) 1382*da0073e9SAndroid Build Coastguard Worker test(op, [get([2, 5, B0, 3])], in_dims=2) 1383*da0073e9SAndroid Build Coastguard Worker test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker # Doubly nested vmap 1386*da0073e9SAndroid Build Coastguard Worker test(vmap(op), [get([B0, B1])]) 1387*da0073e9SAndroid Build Coastguard Worker test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2) 1388*da0073e9SAndroid Build Coastguard Worker test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2) 1389*da0073e9SAndroid Build Coastguard Worker 1390*da0073e9SAndroid Build Coastguard Worker # correctness tests 1391*da0073e9SAndroid Build Coastguard Worker run_test(torch.float) 1392*da0073e9SAndroid Build Coastguard Worker run_test(torch.cfloat) 1393*da0073e9SAndroid Build Coastguard Worker 1394*da0073e9SAndroid Build Coastguard Worker # check that torch.conj on a non-complex tensor returns the same tensor 1395*da0073e9SAndroid Build Coastguard Worker real_tensor = torch.randn(3) 1396*da0073e9SAndroid Build Coastguard Worker result = vmap(op)(real_tensor) 1397*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(result.data_ptr(), real_tensor.data_ptr()) 1398*da0073e9SAndroid Build Coastguard Worker 1399*da0073e9SAndroid Build Coastguard Worker def test_contiguous(self): 1400*da0073e9SAndroid Build Coastguard Worker op = Tensor.contiguous 1401*da0073e9SAndroid Build Coastguard Worker 1402*da0073e9SAndroid Build Coastguard Worker self._test_unary(op, TensorFactory.randn, "cpu") 1403*da0073e9SAndroid Build Coastguard Worker 1404*da0073e9SAndroid Build Coastguard Worker # check that contiguous returns the original tensor if the per-examples 1405*da0073e9SAndroid Build Coastguard Worker # are already contiguous 1406*da0073e9SAndroid Build Coastguard Worker B0 = 3 1407*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, 2, 5, 7) 1408*da0073e9SAndroid Build Coastguard Worker x = x.movedim(0, 2) 1409*da0073e9SAndroid Build Coastguard Worker result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x) 1410*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result is x) 1411*da0073e9SAndroid Build Coastguard Worker 1412*da0073e9SAndroid Build Coastguard Worker msg = "NYI: querying is_contiguous inside of vmap for memory_format" 1413*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(B0, 3) 1414*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1415*da0073e9SAndroid Build Coastguard Worker vmap(functools.partial(op, memory_format=torch.channels_last))(tensor) 1416*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1417*da0073e9SAndroid Build Coastguard Worker vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor) 1418*da0073e9SAndroid Build Coastguard Worker 1419*da0073e9SAndroid Build Coastguard Worker def test_stride(self): 1420*da0073e9SAndroid Build Coastguard Worker B0 = 3 1421*da0073e9SAndroid Build Coastguard Worker 1422*da0073e9SAndroid Build Coastguard Worker x = torch.randn(B0, 2, 5, 7) 1423*da0073e9SAndroid Build Coastguard Worker 
1424*da0073e9SAndroid Build Coastguard Worker def foo(x): 1425*da0073e9SAndroid Build Coastguard Worker assert x.stride() == (7 * 5, 7, 1) 1426*da0073e9SAndroid Build Coastguard Worker return x 1427*da0073e9SAndroid Build Coastguard Worker 1428*da0073e9SAndroid Build Coastguard Worker vmap(foo)(x) 1429*da0073e9SAndroid Build Coastguard Worker 1430*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, B0, 5, 7).movedim(1, 0) 1431*da0073e9SAndroid Build Coastguard Worker 1432*da0073e9SAndroid Build Coastguard Worker def bar(x): 1433*da0073e9SAndroid Build Coastguard Worker assert x.stride() == (7 * 5 * B0, 7, 1) 1434*da0073e9SAndroid Build Coastguard Worker return x 1435*da0073e9SAndroid Build Coastguard Worker 1436*da0073e9SAndroid Build Coastguard Worker vmap(bar)(x) 1437*da0073e9SAndroid Build Coastguard Worker 1438*da0073e9SAndroid Build Coastguard Worker def test_chunk(self): 1439*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 1440*da0073e9SAndroid Build Coastguard Worker op = torch.chunk 1441*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 1442*da0073e9SAndroid Build Coastguard Worker 1443*da0073e9SAndroid Build Coastguard Worker # tests for torch.split(self, split_size: int, dim) 1444*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None)) 1445*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None)) 1446*da0073e9SAndroid Build Coastguard Worker test( 1447*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None, None)), 1448*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 1023, B0, 5), 4, 0), 1449*da0073e9SAndroid Build Coastguard Worker in_dims=(2, None, None), 1450*da0073e9SAndroid Build Coastguard Worker ) 1451*da0073e9SAndroid Build Coastguard Worker test( 1452*da0073e9SAndroid Build Coastguard Worker vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), 1453*da0073e9SAndroid Build Coastguard 
Worker (torch.rand(B1, 2, B0, 64, B2),), 1454*da0073e9SAndroid Build Coastguard Worker in_dims=2, 1455*da0073e9SAndroid Build Coastguard Worker ) 1456*da0073e9SAndroid Build Coastguard Worker 1457*da0073e9SAndroid Build Coastguard Worker def test_clamp(self): 1458*da0073e9SAndroid Build Coastguard Worker clamp_cases = ( 1459*da0073e9SAndroid Build Coastguard Worker (lambda t: t.clamp(min=-0.5), TensorFactory.randn), 1460*da0073e9SAndroid Build Coastguard Worker (lambda t: t.clamp(max=0.5), TensorFactory.randn), 1461*da0073e9SAndroid Build Coastguard Worker (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn), 1462*da0073e9SAndroid Build Coastguard Worker (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn), 1463*da0073e9SAndroid Build Coastguard Worker (lambda t: t.clamp_max(max=0.5), TensorFactory.randn), 1464*da0073e9SAndroid Build Coastguard Worker ) 1465*da0073e9SAndroid Build Coastguard Worker for op, getter in clamp_cases: 1466*da0073e9SAndroid Build Coastguard Worker self._test_unary(op, getter, "cpu") 1467*da0073e9SAndroid Build Coastguard Worker 1468*da0073e9SAndroid Build Coastguard Worker def test_comparison_ops(self): 1469*da0073e9SAndroid Build Coastguard Worker test = functools.partial(self._vmap_test, check_propagates_grad=False) 1470*da0073e9SAndroid Build Coastguard Worker 1471*da0073e9SAndroid Build Coastguard Worker getter = TensorFactory.randn 1472*da0073e9SAndroid Build Coastguard Worker B0, B1 = 7, 11 1473*da0073e9SAndroid Build Coastguard Worker 1474*da0073e9SAndroid Build Coastguard Worker ops = ( 1475*da0073e9SAndroid Build Coastguard Worker torch.eq, 1476*da0073e9SAndroid Build Coastguard Worker lambda x, y: x == y, 1477*da0073e9SAndroid Build Coastguard Worker torch.gt, 1478*da0073e9SAndroid Build Coastguard Worker lambda x, y: x > y, 1479*da0073e9SAndroid Build Coastguard Worker torch.ge, 1480*da0073e9SAndroid Build Coastguard Worker lambda x, y: x >= y, 1481*da0073e9SAndroid Build Coastguard Worker torch.le, 
1482*da0073e9SAndroid Build Coastguard Worker lambda x, y: x <= y, 1483*da0073e9SAndroid Build Coastguard Worker torch.lt, 1484*da0073e9SAndroid Build Coastguard Worker lambda x, y: x < y, 1485*da0073e9SAndroid Build Coastguard Worker torch.ne, 1486*da0073e9SAndroid Build Coastguard Worker lambda x, y: x != y, 1487*da0073e9SAndroid Build Coastguard Worker ) 1488*da0073e9SAndroid Build Coastguard Worker 1489*da0073e9SAndroid Build Coastguard Worker for op in ops: 1490*da0073e9SAndroid Build Coastguard Worker # Single vmap: op(Tensor, Tensor) 1491*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0, 3]), getter([B0, 3]))) 1492*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0]), getter([B0, 2, 3]))) 1493*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1)) 1494*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1) 1495*da0073e9SAndroid Build Coastguard Worker test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None)) 1496*da0073e9SAndroid Build Coastguard Worker test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None)) 1497*da0073e9SAndroid Build Coastguard Worker 1498*da0073e9SAndroid Build Coastguard Worker # Nested vmap: op(Tensor, Tensor) 1499*da0073e9SAndroid Build Coastguard Worker test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3]))) 1500*da0073e9SAndroid Build Coastguard Worker test( 1501*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(None, 0)), 1502*da0073e9SAndroid Build Coastguard Worker (getter([B0, 2, 3]), getter([B1, 3])), 1503*da0073e9SAndroid Build Coastguard Worker in_dims=(0, None), 1504*da0073e9SAndroid Build Coastguard Worker ) 1505*da0073e9SAndroid Build Coastguard Worker 1506*da0073e9SAndroid Build Coastguard Worker # test number as inputs 1507*da0073e9SAndroid Build Coastguard Worker number = getter([]).item() 1508*da0073e9SAndroid Build Coastguard Worker self._test_unary( 
1509*da0073e9SAndroid Build Coastguard Worker lambda t: op(t, number), getter, "cpu", check_propagates_grad=False 1510*da0073e9SAndroid Build Coastguard Worker ) 1511*da0073e9SAndroid Build Coastguard Worker 1512*da0073e9SAndroid Build Coastguard Worker def test_diagonal(self): 1513*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(3, 5, 7, 11, 13) 1514*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 1515*da0073e9SAndroid Build Coastguard Worker op = torch.diagonal 1516*da0073e9SAndroid Build Coastguard Worker test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None)) 1517*da0073e9SAndroid Build Coastguard Worker test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None)) 1518*da0073e9SAndroid Build Coastguard Worker test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None)) 1519*da0073e9SAndroid Build Coastguard Worker test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1) 1520*da0073e9SAndroid Build Coastguard Worker test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1) 1521*da0073e9SAndroid Build Coastguard Worker test( 1522*da0073e9SAndroid Build Coastguard Worker vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3), 1523*da0073e9SAndroid Build Coastguard Worker (tensor,), 1524*da0073e9SAndroid Build Coastguard Worker in_dims=1, 1525*da0073e9SAndroid Build Coastguard Worker out_dims=1, 1526*da0073e9SAndroid Build Coastguard Worker ) 1527*da0073e9SAndroid Build Coastguard Worker 1528*da0073e9SAndroid Build Coastguard Worker def test_dot(self): 1529*da0073e9SAndroid Build Coastguard Worker op = torch.dot 1530*da0073e9SAndroid Build Coastguard Worker test = self._vmap_test 1531*da0073e9SAndroid Build Coastguard Worker B0, B1 = 7, 11 1532*da0073e9SAndroid Build Coastguard Worker 1533*da0073e9SAndroid Build Coastguard Worker # shape mismatch 1534*da0073e9SAndroid Build Coastguard Worker msg = "Shape mismatch" 1535*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, msg): 1536*da0073e9SAndroid Build Coastguard Worker vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2)) 1537*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1538*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2)) 1539*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1540*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2)) 1541*da0073e9SAndroid Build Coastguard Worker 1542*da0073e9SAndroid Build Coastguard Worker # left arg is vmapped 1543*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None)) 1544*da0073e9SAndroid Build Coastguard Worker test( 1545*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None)), 1546*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, B0, 5), torch.rand(5)), 1547*da0073e9SAndroid Build Coastguard Worker in_dims=(1, None), 1548*da0073e9SAndroid Build Coastguard Worker ) 1549*da0073e9SAndroid Build Coastguard Worker 1550*da0073e9SAndroid Build Coastguard Worker # right arg is vmapped 1551*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0)) 1552*da0073e9SAndroid Build Coastguard Worker test( 1553*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(None, 0)), 1554*da0073e9SAndroid Build Coastguard Worker (torch.rand(5), torch.rand(B1, B0, 5)), 1555*da0073e9SAndroid Build Coastguard Worker in_dims=(None, 1), 1556*da0073e9SAndroid Build Coastguard Worker ) 1557*da0073e9SAndroid Build Coastguard Worker 1558*da0073e9SAndroid Build Coastguard Worker # both args are vmapped 1559*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 5), torch.rand(B0, 5))) 1560*da0073e9SAndroid Build Coastguard Worker test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)) 
1561*da0073e9SAndroid Build Coastguard Worker test( 1562*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None)), 1563*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 5), torch.rand(B0, 5)), 1564*da0073e9SAndroid Build Coastguard Worker in_dims=(None, 0), 1565*da0073e9SAndroid Build Coastguard Worker ) 1566*da0073e9SAndroid Build Coastguard Worker 1567*da0073e9SAndroid Build Coastguard Worker def test_expand_as(self): 1568*da0073e9SAndroid Build Coastguard Worker op = torch.Tensor.expand_as 1569*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 1570*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 1571*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5))) 1572*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None)) 1573*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0)) 1574*da0073e9SAndroid Build Coastguard Worker test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5))) 1575*da0073e9SAndroid Build Coastguard Worker test( 1576*da0073e9SAndroid Build Coastguard Worker vmap(op), 1577*da0073e9SAndroid Build Coastguard Worker (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), 1578*da0073e9SAndroid Build Coastguard Worker in_dims=(0, 1), 1579*da0073e9SAndroid Build Coastguard Worker ) 1580*da0073e9SAndroid Build Coastguard Worker test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None)) 1581*da0073e9SAndroid Build Coastguard Worker test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5))) 1582*da0073e9SAndroid Build Coastguard Worker 1583*da0073e9SAndroid Build Coastguard Worker def test_fill_and_zero_inplace(self): 1584*da0073e9SAndroid Build Coastguard Worker test = functools.partial(self._vmap_test, check_propagates_grad=False) 1585*da0073e9SAndroid Build Coastguard Worker B0, B1 = 7, 
11 1586*da0073e9SAndroid Build Coastguard Worker ops = ( 1587*da0073e9SAndroid Build Coastguard Worker lambda t: t.fill_(0.1), 1588*da0073e9SAndroid Build Coastguard Worker lambda t: t.fill_(torch.tensor(0.2)), 1589*da0073e9SAndroid Build Coastguard Worker lambda t: t.zero_(), 1590*da0073e9SAndroid Build Coastguard Worker ) 1591*da0073e9SAndroid Build Coastguard Worker 1592*da0073e9SAndroid Build Coastguard Worker for op in ops: 1593*da0073e9SAndroid Build Coastguard Worker # Single vmap, various in_dims / out_dims 1594*da0073e9SAndroid Build Coastguard Worker test(op, [TensorFactory.randn([B0, 3])]) 1595*da0073e9SAndroid Build Coastguard Worker test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2) 1596*da0073e9SAndroid Build Coastguard Worker test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2) 1597*da0073e9SAndroid Build Coastguard Worker 1598*da0073e9SAndroid Build Coastguard Worker # Doubly nested vmap 1599*da0073e9SAndroid Build Coastguard Worker test(vmap(op), [TensorFactory.randn([B0, B1])]) 1600*da0073e9SAndroid Build Coastguard Worker test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2) 1601*da0073e9SAndroid Build Coastguard Worker test( 1602*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=2), 1603*da0073e9SAndroid Build Coastguard Worker [TensorFactory.randn([2, 5, B0, B1, 3])], 1604*da0073e9SAndroid Build Coastguard Worker in_dims=2, 1605*da0073e9SAndroid Build Coastguard Worker out_dims=2, 1606*da0073e9SAndroid Build Coastguard Worker ) 1607*da0073e9SAndroid Build Coastguard Worker 1608*da0073e9SAndroid Build Coastguard Worker # test when value is a batched tensor for fill_ operator 1609*da0073e9SAndroid Build Coastguard Worker B0, B1 = 3, 5 1610*da0073e9SAndroid Build Coastguard Worker test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)]) 1611*da0073e9SAndroid Build Coastguard Worker 1612*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1613*da0073e9SAndroid Build 
Coastguard Worker RuntimeError, r"output with shape .+ doesn't match the broadcast shape" 1614*da0073e9SAndroid Build Coastguard Worker ): 1615*da0073e9SAndroid Build Coastguard Worker # Runtime Error is thrown when the tensor being written to isn't being vmapped over 1616*da0073e9SAndroid Build Coastguard Worker vmap(Tensor.fill_, (None, 0))( 1617*da0073e9SAndroid Build Coastguard Worker TensorFactory.randn([B0, B1]), TensorFactory.randn([B0]) 1618*da0073e9SAndroid Build Coastguard Worker ) 1619*da0073e9SAndroid Build Coastguard Worker 1620*da0073e9SAndroid Build Coastguard Worker def _test_complex_views(self, op, dtypes): 1621*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 1622*da0073e9SAndroid Build Coastguard Worker 1623*da0073e9SAndroid Build Coastguard Worker def run_test(op, dtype): 1624*da0073e9SAndroid Build Coastguard Worker def get(shape): 1625*da0073e9SAndroid Build Coastguard Worker return torch.randn(shape, dtype=dtype) 1626*da0073e9SAndroid Build Coastguard Worker 1627*da0073e9SAndroid Build Coastguard Worker B0, B1 = 7, 11 1628*da0073e9SAndroid Build Coastguard Worker 1629*da0073e9SAndroid Build Coastguard Worker # Single vmap, various in_dims / out_dims 1630*da0073e9SAndroid Build Coastguard Worker test(op, [get([B0, 3])]) 1631*da0073e9SAndroid Build Coastguard Worker test(op, [get([3, B0])], in_dims=1) 1632*da0073e9SAndroid Build Coastguard Worker test(op, [get([2, 5, B0, 3])], in_dims=2) 1633*da0073e9SAndroid Build Coastguard Worker test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2) 1634*da0073e9SAndroid Build Coastguard Worker 1635*da0073e9SAndroid Build Coastguard Worker # Doubly nested vmap 1636*da0073e9SAndroid Build Coastguard Worker test(vmap(op), [get([B0, B1])]) 1637*da0073e9SAndroid Build Coastguard Worker test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4) 1638*da0073e9SAndroid Build Coastguard Worker test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2) 1639*da0073e9SAndroid Build 
Coastguard Worker 1640*da0073e9SAndroid Build Coastguard Worker for dtype in dtypes: 1641*da0073e9SAndroid Build Coastguard Worker run_test(op, dtype) 1642*da0073e9SAndroid Build Coastguard Worker 1643*da0073e9SAndroid Build Coastguard Worker def test_real(self): 1644*da0073e9SAndroid Build Coastguard Worker self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble]) 1645*da0073e9SAndroid Build Coastguard Worker 1646*da0073e9SAndroid Build Coastguard Worker def test_imag(self): 1647*da0073e9SAndroid Build Coastguard Worker self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble]) 1648*da0073e9SAndroid Build Coastguard Worker 1649*da0073e9SAndroid Build Coastguard Worker def test_view_as_real(self): 1650*da0073e9SAndroid Build Coastguard Worker self._test_complex_views( 1651*da0073e9SAndroid Build Coastguard Worker torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble] 1652*da0073e9SAndroid Build Coastguard Worker ) 1653*da0073e9SAndroid Build Coastguard Worker 1654*da0073e9SAndroid Build Coastguard Worker def test_view_as_complex(self): 1655*da0073e9SAndroid Build Coastguard Worker def run_test(dtype): 1656*da0073e9SAndroid Build Coastguard Worker def get(shape): 1657*da0073e9SAndroid Build Coastguard Worker return torch.randn(shape, dtype=dtype) 1658*da0073e9SAndroid Build Coastguard Worker 1659*da0073e9SAndroid Build Coastguard Worker op = torch.view_as_complex 1660*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 1661*da0073e9SAndroid Build Coastguard Worker B0, B1 = 7, 11 1662*da0073e9SAndroid Build Coastguard Worker 1663*da0073e9SAndroid Build Coastguard Worker # Single vmap, various in_dims / out_dims 1664*da0073e9SAndroid Build Coastguard Worker test(op, [get([B0, 3, 2])]) 1665*da0073e9SAndroid Build Coastguard Worker test(op, [get([2, 5, B0, 3, 2])], in_dims=2) 1666*da0073e9SAndroid Build Coastguard Worker test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2) 1667*da0073e9SAndroid Build Coastguard 
Worker 1668*da0073e9SAndroid Build Coastguard Worker # Doubly nested vmap 1669*da0073e9SAndroid Build Coastguard Worker test(vmap(op), [get([B0, B1, 2])]) 1670*da0073e9SAndroid Build Coastguard Worker test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2) 1671*da0073e9SAndroid Build Coastguard Worker test( 1672*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])], in_dims=2, out_dims=2 1673*da0073e9SAndroid Build Coastguard Worker ) 1674*da0073e9SAndroid Build Coastguard Worker 1675*da0073e9SAndroid Build Coastguard Worker # Interesting case #1: Batch dim directly before dim of size 2 1676*da0073e9SAndroid Build Coastguard Worker test(op, [get([3, B0, 2])], in_dims=1) 1677*da0073e9SAndroid Build Coastguard Worker test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2) 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Worker # Interesting case #2: Batch dim at end of tensor, success cases 1680*da0073e9SAndroid Build Coastguard Worker # view_as_complex requires that the dim with size 2 have stride 1 1681*da0073e9SAndroid Build Coastguard Worker # in order for the view to function propertly 1682*da0073e9SAndroid Build Coastguard Worker test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1) 1683*da0073e9SAndroid Build Coastguard Worker test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)]) 1684*da0073e9SAndroid Build Coastguard Worker test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)]) 1685*da0073e9SAndroid Build Coastguard Worker 1686*da0073e9SAndroid Build Coastguard Worker # Interesting case #3: Batch dim at end of tensor, failure cases 1687*da0073e9SAndroid Build Coastguard Worker msg = "Tensor must have a last dimension with stride 1" 1688*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1689*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=1)(get([2, B0])) 1690*da0073e9SAndroid Build Coastguard Worker with 
self.assertRaisesRegex(RuntimeError, msg): 1691*da0073e9SAndroid Build Coastguard Worker vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1])) 1692*da0073e9SAndroid Build Coastguard Worker 1693*da0073e9SAndroid Build Coastguard Worker # Invalid input: no dimension of size 2 1694*da0073e9SAndroid Build Coastguard Worker msg = "Input tensor must have one or more dimensions" 1695*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1696*da0073e9SAndroid Build Coastguard Worker vmap(op)(get([B0])) 1697*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1698*da0073e9SAndroid Build Coastguard Worker vmap(vmap(op))(get([B0, B1])) 1699*da0073e9SAndroid Build Coastguard Worker 1700*da0073e9SAndroid Build Coastguard Worker # Invalid input: Batch dim has size 2, but the logical last dim does 1701*da0073e9SAndroid Build Coastguard Worker # not have size 2 1702*da0073e9SAndroid Build Coastguard Worker msg = "Tensor must have a last dimension of size 2" 1703*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1704*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=1)(get([3, 2])) 1705*da0073e9SAndroid Build Coastguard Worker 1706*da0073e9SAndroid Build Coastguard Worker for dtype in [torch.float, torch.double]: 1707*da0073e9SAndroid Build Coastguard Worker run_test(dtype) 1708*da0073e9SAndroid Build Coastguard Worker 1709*da0073e9SAndroid Build Coastguard Worker def test_is_complex(self): 1710*da0073e9SAndroid Build Coastguard Worker ctensor = torch.randn(3, dtype=torch.cfloat) 1711*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(3) 1712*da0073e9SAndroid Build Coastguard Worker 1713*da0073e9SAndroid Build Coastguard Worker def foo(x): 1714*da0073e9SAndroid Build Coastguard Worker if x.is_complex(): 1715*da0073e9SAndroid Build Coastguard Worker return torch.tensor(1) 1716*da0073e9SAndroid Build Coastguard Worker else: 1717*da0073e9SAndroid Build 
Coastguard Worker return torch.tensor(0) 1718*da0073e9SAndroid Build Coastguard Worker 1719*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1])) 1720*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0])) 1721*da0073e9SAndroid Build Coastguard Worker 1722*da0073e9SAndroid Build Coastguard Worker def test_is_floating_point(self): 1723*da0073e9SAndroid Build Coastguard Worker float_tensor = torch.tensor([1.0, 2.0, 3.0]) 1724*da0073e9SAndroid Build Coastguard Worker long_tensor = torch.tensor([1, 2, 3]) 1725*da0073e9SAndroid Build Coastguard Worker 1726*da0073e9SAndroid Build Coastguard Worker def foo(x): 1727*da0073e9SAndroid Build Coastguard Worker if x.is_floating_point(): 1728*da0073e9SAndroid Build Coastguard Worker return torch.tensor(1) 1729*da0073e9SAndroid Build Coastguard Worker else: 1730*da0073e9SAndroid Build Coastguard Worker return torch.tensor(0) 1731*da0073e9SAndroid Build Coastguard Worker 1732*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1])) 1733*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0])) 1734*da0073e9SAndroid Build Coastguard Worker 1735*da0073e9SAndroid Build Coastguard Worker def test_is_contiguous(self): 1736*da0073e9SAndroid Build Coastguard Worker def foo(x): 1737*da0073e9SAndroid Build Coastguard Worker if x.is_contiguous(): 1738*da0073e9SAndroid Build Coastguard Worker return torch.tensor(1.0) 1739*da0073e9SAndroid Build Coastguard Worker else: 1740*da0073e9SAndroid Build Coastguard Worker return torch.tensor(0.0) 1741*da0073e9SAndroid Build Coastguard Worker 1742*da0073e9SAndroid Build Coastguard Worker B0, B1 = 3, 5 1743*da0073e9SAndroid Build Coastguard Worker 1744*da0073e9SAndroid Build Coastguard Worker # Single batch dim 1745*da0073e9SAndroid Build Coastguard Worker contig = torch.randn(B0, 2, 7) 1746*da0073e9SAndroid 
Build Coastguard Worker self.assertEqual(vmap(foo)(contig), torch.ones(B0)) 1747*da0073e9SAndroid Build Coastguard Worker 1748*da0073e9SAndroid Build Coastguard Worker noncontig = torch.randn(2, B0, 7) 1749*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0)) 1750*da0073e9SAndroid Build Coastguard Worker 1751*da0073e9SAndroid Build Coastguard Worker noncontig = torch.randn(2, B0, 7).movedim(1, 0) 1752*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0)) 1753*da0073e9SAndroid Build Coastguard Worker 1754*da0073e9SAndroid Build Coastguard Worker noncontig = torch.randn(2, 7, B0) 1755*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0)) 1756*da0073e9SAndroid Build Coastguard Worker 1757*da0073e9SAndroid Build Coastguard Worker # Multiple batch dims 1758*da0073e9SAndroid Build Coastguard Worker contig = torch.randn(B0, B1, 3) 1759*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) 1760*da0073e9SAndroid Build Coastguard Worker 1761*da0073e9SAndroid Build Coastguard Worker contig = torch.randn(B1, B0, 3) 1762*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1)) 1763*da0073e9SAndroid Build Coastguard Worker 1764*da0073e9SAndroid Build Coastguard Worker contig = torch.randn(B1, B0, 3).movedim(0, 1) 1765*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1)) 1766*da0073e9SAndroid Build Coastguard Worker 1767*da0073e9SAndroid Build Coastguard Worker noncontig = torch.randn(B0, 3, B1) 1768*da0073e9SAndroid Build Coastguard Worker self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1)) 1769*da0073e9SAndroid Build Coastguard Worker 1770*da0073e9SAndroid Build Coastguard Worker # is_contiguous on empty tensor is True 1771*da0073e9SAndroid Build 
Coastguard Worker def bar(x): 1772*da0073e9SAndroid Build Coastguard Worker assert x.is_contiguous() 1773*da0073e9SAndroid Build Coastguard Worker return x 1774*da0073e9SAndroid Build Coastguard Worker 1775*da0073e9SAndroid Build Coastguard Worker vmap(bar)(torch.randn(B0, 0, 3)) 1776*da0073e9SAndroid Build Coastguard Worker vmap(bar, in_dims=1)(torch.randn(0, B0, 3)) 1777*da0073e9SAndroid Build Coastguard Worker vmap(bar)(torch.randn(B0, 0, 3).mT) 1778*da0073e9SAndroid Build Coastguard Worker 1779*da0073e9SAndroid Build Coastguard Worker # is_contiguous with other memory formats 1780*da0073e9SAndroid Build Coastguard Worker def baz(x, memory_format): 1781*da0073e9SAndroid Build Coastguard Worker x.is_contiguous(memory_format=memory_format) 1782*da0073e9SAndroid Build Coastguard Worker return x 1783*da0073e9SAndroid Build Coastguard Worker 1784*da0073e9SAndroid Build Coastguard Worker msg = "NYI: querying is_contiguous inside of vmap for memory_format" 1785*da0073e9SAndroid Build Coastguard Worker tensor = torch.randn(B0, 2, 7, 3) 1786*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1787*da0073e9SAndroid Build Coastguard Worker vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor) 1788*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 1789*da0073e9SAndroid Build Coastguard Worker vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor) 1790*da0073e9SAndroid Build Coastguard Worker 1791*da0073e9SAndroid Build Coastguard Worker def test_movedim(self): 1792*da0073e9SAndroid Build Coastguard Worker op = torch.movedim 1793*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 1794*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 1795*da0073e9SAndroid Build Coastguard Worker 1796*da0073e9SAndroid Build Coastguard Worker # movedim(tensor, int, int) variant 1797*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 
5), 0, 1), in_dims=(0, None, None)) 1798*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None)) 1799*da0073e9SAndroid Build Coastguard Worker test( 1800*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None, None)), 1801*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2, B0, 5), 0, 1), 1802*da0073e9SAndroid Build Coastguard Worker in_dims=(2, None, None), 1803*da0073e9SAndroid Build Coastguard Worker ) 1804*da0073e9SAndroid Build Coastguard Worker test( 1805*da0073e9SAndroid Build Coastguard Worker vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)), 1806*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2, B0, 5, B2), 0, 1), 1807*da0073e9SAndroid Build Coastguard Worker in_dims=(2, None, None), 1808*da0073e9SAndroid Build Coastguard Worker ) 1809*da0073e9SAndroid Build Coastguard Worker 1810*da0073e9SAndroid Build Coastguard Worker # movedim(tensor, intlist, intlist) variant 1811*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None)) 1812*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None)) 1813*da0073e9SAndroid Build Coastguard Worker test( 1814*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None, None)), 1815*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]), 1816*da0073e9SAndroid Build Coastguard Worker in_dims=(2, None, None), 1817*da0073e9SAndroid Build Coastguard Worker ) 1818*da0073e9SAndroid Build Coastguard Worker test( 1819*da0073e9SAndroid Build Coastguard Worker vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)), 1820*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), 1821*da0073e9SAndroid Build Coastguard Worker in_dims=(2, None, None), 1822*da0073e9SAndroid Build Coastguard Worker ) 1823*da0073e9SAndroid Build Coastguard Worker 
    def test_mm(self):
        """Batching rules for torch.mm: each operand vmapped alone and both
        together, plus shape-mismatch error checking."""
        op = torch.mm
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None))
        test(
            vmap(op, in_dims=(0, None)),
            (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)),
            in_dims=(1, None),
        )

        # right arg is vmapped
        test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
        test(
            vmap(op, in_dims=(None, 0)),
            (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)),
            in_dims=(None, 1),
        )

        # both args are vmapped
        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2)))
        test(
            vmap(op),
            (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)),
            in_dims=(1, 0),
        )
        test(
            vmap(op, in_dims=(0, None)),
            (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)),
            in_dims=(None, 0),
        )

    def test_mv(self):
        """Batching rules for torch.mv (matrix @ vector), mirroring test_mm."""
        op = torch.mv
        test = self._vmap_test
        B0, B1 = 7, 11

        # shape mismatch
        msg = "Shape mismatch"
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2))
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2))

        # left arg is vmapped
        test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None))
        test(
            vmap(op, in_dims=(0, None)),
            (torch.rand(B1, B0, 2, 5), torch.rand(5)),
            in_dims=(1, None),
        )

        # right arg is vmapped
        test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0))
        test(
            vmap(op, in_dims=(None, 0)),
            (torch.rand(2, 5), torch.rand(B1, B0, 5)),
            in_dims=(None, 1),
        )

        # both args are vmapped
        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5)))
        test(
            vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)
        )
        test(
            vmap(op, in_dims=(0, None)),
            (torch.rand(B1, 2, 5), torch.rand(B0, 5)),
            in_dims=(None, 0),
        )

    def test_narrow(self):
        """Batching rules for torch.narrow (view op), including negative dims
        and doubly-nested vmap."""
        op = torch.narrow
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13

        test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None))
        test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None))
        test(
            vmap(op, in_dims=(0, None, None, None)),
            (torch.rand(B1, 2, B0, 5), 1, 0, 0),
            in_dims=(2, None, None, None),
        )
        test(
            vmap(
                vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)
            ),
            (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3),
            in_dims=(2, None, None, None),
        )

    def test_new_empty(self):
        """Tensor.new_empty under vmap: output shape gains the batch dims."""
        # Empty is non-deterministic so we just check that the shape of the
        # output tensor is what we expect and that the vmap fallback isn't used.
        op = Tensor.new_empty

        B0, B1 = 7, 11

        result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0))
        self.assertEqual(result.shape, [B0, 2, 3])

        # Empty size list: result is just the batch dimension.
        result = vmap(lambda x: op(x, []))(torch.randn(B0))
        self.assertEqual(result.shape, [B0])

        result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1))
        self.assertEqual(result.shape, [B0, B1, 2, 3])

    def test_new_empty_strided(self):
        """Tensor.new_empty_strided under vmap: per-example strides are laid
        out back-to-back, spaced by the per-example storage size S."""
        # Empty is non-deterministic so we just check that the size and shape
        # of the output are what we expect and that the vmap fallback isn't used
        B0, B1 = 7, 11

        def _test_single_vmap(size, stride, B0):
            # One batch dim: expected stride is [S] + stride where S is the
            # storage footprint of a single (size, stride) example.
            # NOTE(review): .storage().size() is the legacy accessor —
            # presumably kept for compatibility with older PyTorch; confirm
            # before modernizing to untyped_storage().
            x = torch.randn(B0)
            result = vmap(lambda x: x.new_empty_strided(size, stride))(x)
            S = torch.empty_strided(size, stride).storage().size()
            self.assertEqual(result.shape, [B0] + size)
            self.assertEqual(result.stride(), [S] + stride)

        def _test_double_vmap(size, stride, B0, B1):
            # Two batch dims: strides become [B1 * S, S] + stride.
            x = torch.randn(B0, B1)
            result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x)
            S = torch.empty_strided(size, stride).storage().size()
            self.assertEqual(result.shape, [B0, B1] + size)
            self.assertEqual(result.stride(), [B1 * S, S] + stride)

            # Same check with the batch dims transposed on the input.
            x = torch.randn(B1, B0)
            result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)(
                x
            )
            S = x.new_empty_strided(size, stride).storage().size()
            self.assertEqual(result.shape, [B0, B1] + size)
            self.assertEqual(result.stride(), [B1 * S, S] + stride)

        # contiguous case
        _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0)
        _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1)

        # expanded
        _test_single_vmap([2, 3, 5], [0, 5, 1], B0)
        _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1)

        # some of these cases are pretty strange, just verifying that if
        # empty_strided allows them then BatchedTensor.new_empty_strided
        # can as well
        for shape in [[2, 3, 4], [0, 2, 0]]:
            for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]:
                _test_single_vmap(shape, strides, B0)
                _test_double_vmap(shape, strides, B0, B1)

    def test_new_zeros(self):
        """Tensor.new_zeros under vmap; gradients are not expected to
        propagate through a factory function, hence check_propagates_grad=False."""
        op = Tensor.new_zeros
        test = functools.partial(self._vmap_test, check_propagates_grad=False)
        B0, B1 = 7, 11

        test(lambda x: op(x, 2, 3), (torch.rand(B0),))
        test(lambda x: op(x, []), (torch.rand(B0),))
        test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),))

    def test_select(self):
        """Batching rules for torch.select (view op)."""
        op = torch.select
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))
        test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
        test(
            vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)),
            (torch.rand(B1, 2, B0, B2, 5),),
            in_dims=2,
        )

    def test_stack(self):
        """Batching rules for torch.stack, wrapped so each tensor is a
        separate positional argument."""
        test = self._vmap_test
        B0, B1 = 5, 7

        # Quick hack b/c vmap can't accept a list of tensors as an argument
        def get_op(dim):
            def op(*tensors):
                return torch.stack(tensors, dim=dim)

            return op

        test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3)))
        test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0))
        test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
        test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
        test(
            vmap(get_op(0), in_dims=(0, None)),
            (torch.rand(B1, 2), torch.rand(B0, 2)),
            in_dims=(None, 0),
        )
        test(
            vmap(get_op(0), in_dims=(0, 0)),
            (torch.rand(B1, 2), torch.rand(B0, B1, 2)),
            in_dims=(None, 0),
        )

    def test_slice(self):
        """Batching rules for basic slicing (Tensor.__getitem__ with slices)."""
        test = self._vmap_view_test
        B0, B1, B2 = 7, 11, 13
        test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
        test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)
        test(
            vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2
        )
        test(
            vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
            (torch.rand(3, 5, B0, B1, B2),),
            in_dims=2,
        )

    def test_squeeze(self):
        """Batching rules for torch.squeeze; B0 = 1 deliberately so the batch
        dim itself has size 1 and must NOT be squeezed away."""
        test = self._vmap_view_test
        op = torch.squeeze
        B0, B1 = 1, 11
        test(op, (torch.rand(B0),))
        test(op, (torch.rand(B0, 3, 5),))
        test(op, (torch.rand(1, B0, 5),), in_dims=1)
        test(op, (torch.rand(B0, 0, 1, 5, 1),))
        test(op, (torch.rand(B0, 1, 1, 1, 1),))
        test(vmap(op), (torch.rand(B0, B1, 1),))
        test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)
    def test_sum_dim(self):
        """Batching rules for Tensor.sum(dim), covering empty dim tuples,
        negative dims, and nested vmap with custom in/out dims."""
        test = self._vmap_test
        B0, B1 = 5, 7

        # Single vmap, various in_dims / out_dims
        test(lambda x: x.sum(()), [torch.randn([B0])])
        test(lambda x: x.sum(()), [torch.randn([B0, 2])])
        test(lambda x: x.sum(0), [torch.randn([B0])])
        test(lambda x: x.sum(-1), [torch.randn([B0])])
        test(lambda x: x.sum(0), [torch.randn([B0, 3])])
        test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2)
        test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)

        # Doubly nested vmap
        test(vmap(lambda x: x.sum(())), [torch.randn([B0, B1])])
        test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])])
        test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])])
        test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
        test(
            vmap(lambda x: x.sum(2), in_dims=2),
            [torch.randn([2, 5, B0, B1, 3])],
            in_dims=2,
            out_dims=2,
        )

    def test_reshape(self):
        """Batching rules for torch.reshape; check_view tracks whether the
        per-example reshape can alias the input (batch dim at the front)."""
        test = self._vmap_test
        B0, B1, B2 = 7, 11, 13
        op = torch.reshape
        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
        test(
            op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False
        )
        test(
            vmap(lambda t: t.reshape([-1])),
            (torch.rand(B0, B1, 2, 5),),
            check_view=True,
        )
        test(
            vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1),
            (torch.rand(3, B1, 2, B2, 5, B0),),
            in_dims=5,
            check_view=False,
        )

    def test_reshape_as(self):
        """Batching rules for Tensor.reshape_as, with either or both of the
        two tensor arguments batched."""
        test = self._vmap_test
        B0, B1, B2 = 7, 11, 13
        op = torch.Tensor.reshape_as
        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True)
        test(
            op,
            (torch.rand(2 * 5), torch.rand(B0, 2, 5)),
            in_dims=(None, 0),
            check_view=True,
        )
        test(
            op,
            (torch.rand(B0, 2 * 5), torch.rand(2, 5)),
            in_dims=(0, None),
            check_view=True,
        )

        test(
            op,
            (torch.rand(2, B0, 5), torch.rand(1, 1, 10)),
            in_dims=(1, None),
            check_view=False,
        )

        test(
            vmap(op),
            (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)),
            check_view=True,
        )
        test(
            vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)),
            (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
            in_dims=(5, 0),
            check_view=False,
        )

    def test_result_type(self):
        """Batching rules for torch.result_type.

        result_type returns a dtype, not a Tensor, so it is wrapped to emit a
        scalar tensor of that dtype; grad propagation is disabled since the
        output carries no autograd relationship to the inputs.
        """

        def scalar_tensor_with_dtype(op):
            # Wrap a dtype-returning op into one that returns a Tensor,
            # which is what vmap requires of outputs.
            def wrapped(*args, **kwargs):
                dtype = op(*args, **kwargs)
                return torch.ones([], dtype=dtype)

            return wrapped

        test = self._vmap_test
        op = scalar_tensor_with_dtype(torch.result_type)

        B0 = 2

        test(
            op,
            (torch.randn(B0), torch.randn(B0, dtype=torch.float64)),
            check_propagates_grad=False,
        )
        test(
            op,
            (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)),
            check_propagates_grad=False,
        )

        # Tensor-with-Python-scalar promotion.
        test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False)
        test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False)

        # Tensor-with-0-dim-tensor promotion.
        test(
            lambda x: op(x, torch.tensor(1)),
            (torch.randn(B0),),
            check_propagates_grad=False,
        )
        test(
            lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
            (torch.randn(B0),),
            check_propagates_grad=False,
        )

        # Same cases again with non-scalar per-example slices.
        test(
            op,
            (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)),
            check_propagates_grad=False,
        )
        test(
            op,
            (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)),
            check_propagates_grad=False,
        )

        test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False)
        test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False)

        test(
            lambda x: op(x, torch.tensor(1)),
            (torch.randn(B0, 2),),
            check_propagates_grad=False,
        )
        test(
            lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
            (torch.randn(B0, 2),),
            check_propagates_grad=False,
        )

        # Mixed ranks between the two batched arguments.
        test(
            op,
            (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)),
            check_propagates_grad=False,
        )
        test(
            op,
            (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)),
            check_propagates_grad=False,
        )

    @skipIfTorchDynamo("too slow")
    def test_tensor_split(self):
        """Batching rules for torch.tensor_split: the int (number of chunks)
        and the index-list overloads, each returning a tuple of views."""
        test = self._vmap_view_test
        op = torch.tensor_split
        B0, B1, B2 = 7, 11, 13

        # tests for torch.tensor_split(self, indices_or_sections: int, dim)
        test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
        test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
        test(
            vmap(op, in_dims=(0, None, None)),
            (torch.rand(B1, 1023, B0, 5), 256, 0),
            in_dims=(2, None, None),
        )
        test(
            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
            (torch.rand(B1, 2, B0, 64, B2),),
            in_dims=2,
        )

        # tests for torch.tensor_split(self, indices_or_sections: List[int], dim)
        test(
            op,
            (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1),
            in_dims=(0, None, None),
        )
        test(
            op,
            (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1),
            in_dims=(1, None, None),
        )
        test(
            vmap(op, in_dims=(0, None, None)),
            (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0),
            in_dims=(2, None, None),
        )
        test(
            vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)),
            (torch.rand(B1, 2, B0, 64, B2),),
            in_dims=2,
        )
= self._vmap_view_test 2261*da0073e9SAndroid Build Coastguard Worker op = torch.split 2262*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 2263*da0073e9SAndroid Build Coastguard Worker 2264*da0073e9SAndroid Build Coastguard Worker # tests for torch.split(self, split_size: int, dim) 2265*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None)) 2266*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None)) 2267*da0073e9SAndroid Build Coastguard Worker test( 2268*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None, None)), 2269*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 1023, B0, 5), 256, 0), 2270*da0073e9SAndroid Build Coastguard Worker in_dims=(2, None, None), 2271*da0073e9SAndroid Build Coastguard Worker ) 2272*da0073e9SAndroid Build Coastguard Worker test( 2273*da0073e9SAndroid Build Coastguard Worker vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)), 2274*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2, B0, 64, B2),), 2275*da0073e9SAndroid Build Coastguard Worker in_dims=2, 2276*da0073e9SAndroid Build Coastguard Worker ) 2277*da0073e9SAndroid Build Coastguard Worker 2278*da0073e9SAndroid Build Coastguard Worker # tests for torch.split(self, split_size: List[int], dim) 2279*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None)) 2280*da0073e9SAndroid Build Coastguard Worker test( 2281*da0073e9SAndroid Build Coastguard Worker op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None) 2282*da0073e9SAndroid Build Coastguard Worker ) 2283*da0073e9SAndroid Build Coastguard Worker test( 2284*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None, None)), 2285*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0), 2286*da0073e9SAndroid Build Coastguard Worker in_dims=(2, None, None), 
2287*da0073e9SAndroid Build Coastguard Worker ) 2288*da0073e9SAndroid Build Coastguard Worker test( 2289*da0073e9SAndroid Build Coastguard Worker vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)), 2290*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2, B0, 64, B2),), 2291*da0073e9SAndroid Build Coastguard Worker in_dims=2, 2292*da0073e9SAndroid Build Coastguard Worker ) 2293*da0073e9SAndroid Build Coastguard Worker 2294*da0073e9SAndroid Build Coastguard Worker def test_trace(self): 2295*da0073e9SAndroid Build Coastguard Worker op = torch.trace 2296*da0073e9SAndroid Build Coastguard Worker test = self._vmap_test 2297*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 2298*da0073e9SAndroid Build Coastguard Worker 2299*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 5),)) 2300*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(2, B0, 5),), in_dims=1) 2301*da0073e9SAndroid Build Coastguard Worker test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2302*da0073e9SAndroid Build Coastguard Worker test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) 2303*da0073e9SAndroid Build Coastguard Worker 2304*da0073e9SAndroid Build Coastguard Worker def test_transpose(self): 2305*da0073e9SAndroid Build Coastguard Worker op = torch.transpose 2306*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 2307*da0073e9SAndroid Build Coastguard Worker 2308*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 2309*da0073e9SAndroid Build Coastguard Worker test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),)) 2310*da0073e9SAndroid Build Coastguard Worker test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),)) 2311*da0073e9SAndroid Build Coastguard Worker test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),)) 2312*da0073e9SAndroid Build Coastguard Worker test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1) 2313*da0073e9SAndroid Build Coastguard Worker test(vmap(lambda 
x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2314*da0073e9SAndroid Build Coastguard Worker test( 2315*da0073e9SAndroid Build Coastguard Worker vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)), 2316*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2, B0, 5, B2),), 2317*da0073e9SAndroid Build Coastguard Worker in_dims=2, 2318*da0073e9SAndroid Build Coastguard Worker ) 2319*da0073e9SAndroid Build Coastguard Worker 2320*da0073e9SAndroid Build Coastguard Worker # Special case: scalar tensor 2321*da0073e9SAndroid Build Coastguard Worker for dim1, dim2 in itertools.product([0, -1], [0, -1]): 2322*da0073e9SAndroid Build Coastguard Worker x = torch.rand(B0) 2323*da0073e9SAndroid Build Coastguard Worker result = vmap(lambda x: op(x, dim1, dim2))(x) 2324*da0073e9SAndroid Build Coastguard Worker self.assertTrue(result is x) 2325*da0073e9SAndroid Build Coastguard Worker 2326*da0073e9SAndroid Build Coastguard Worker def test_t(self): 2327*da0073e9SAndroid Build Coastguard Worker op = torch.t 2328*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 2329*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 2330*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 5),)) 2331*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(2, B0, 5),), in_dims=1) 2332*da0073e9SAndroid Build Coastguard Worker test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2333*da0073e9SAndroid Build Coastguard Worker test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2) 2334*da0073e9SAndroid Build Coastguard Worker 2335*da0073e9SAndroid Build Coastguard Worker def test_T_numpy(self): 2336*da0073e9SAndroid Build Coastguard Worker def op(t): 2337*da0073e9SAndroid Build Coastguard Worker return t.T 2338*da0073e9SAndroid Build Coastguard Worker 2339*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 2340*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 2341*da0073e9SAndroid Build 
Coastguard Worker test(op, (torch.rand(B0, 2, 3, 5),)) 2342*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(2, B0, 3, 5),), in_dims=1) 2343*da0073e9SAndroid Build Coastguard Worker test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2) 2344*da0073e9SAndroid Build Coastguard Worker test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2) 2345*da0073e9SAndroid Build Coastguard Worker test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2) 2346*da0073e9SAndroid Build Coastguard Worker 2347*da0073e9SAndroid Build Coastguard Worker def test_to(self): 2348*da0073e9SAndroid Build Coastguard Worker test = self._vmap_test 2349*da0073e9SAndroid Build Coastguard Worker B0, B1 = 7, 11 2350*da0073e9SAndroid Build Coastguard Worker 2351*da0073e9SAndroid Build Coastguard Worker test(lambda t: t.to("cpu"), (torch.rand(B0),)) 2352*da0073e9SAndroid Build Coastguard Worker test(lambda t: t.to(torch.double), (torch.rand(B0),)) 2353*da0073e9SAndroid Build Coastguard Worker test( 2354*da0073e9SAndroid Build Coastguard Worker lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64)) 2355*da0073e9SAndroid Build Coastguard Worker ) 2356*da0073e9SAndroid Build Coastguard Worker test( 2357*da0073e9SAndroid Build Coastguard Worker lambda t, o: t.to(o), 2358*da0073e9SAndroid Build Coastguard Worker (torch.rand(B0), torch.randn(B0, dtype=torch.float64)), 2359*da0073e9SAndroid Build Coastguard Worker in_dims=(0, None), 2360*da0073e9SAndroid Build Coastguard Worker ) 2361*da0073e9SAndroid Build Coastguard Worker test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),)) 2362*da0073e9SAndroid Build Coastguard Worker 2363*da0073e9SAndroid Build Coastguard Worker # also test some casting methods 2364*da0073e9SAndroid Build Coastguard Worker test(lambda t: t.double(), (torch.rand(B0),)) 2365*da0073e9SAndroid Build Coastguard Worker test(lambda t: t.float(), (torch.rand(B0),)) 2366*da0073e9SAndroid Build Coastguard Worker test(lambda t: 
t.int(), (torch.rand(B0),), check_propagates_grad=False) 2367*da0073e9SAndroid Build Coastguard Worker test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False) 2368*da0073e9SAndroid Build Coastguard Worker 2369*da0073e9SAndroid Build Coastguard Worker def test_unfold(self): 2370*da0073e9SAndroid Build Coastguard Worker op = torch.Tensor.unfold 2371*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 2372*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 3, 2, 5 2373*da0073e9SAndroid Build Coastguard Worker 2374*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None)) 2375*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None)) 2376*da0073e9SAndroid Build Coastguard Worker test( 2377*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None, None, None)), 2378*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 7, B0, 11), 1, 5, 1), 2379*da0073e9SAndroid Build Coastguard Worker in_dims=(2, None, None, None), 2380*da0073e9SAndroid Build Coastguard Worker ) 2381*da0073e9SAndroid Build Coastguard Worker test( 2382*da0073e9SAndroid Build Coastguard Worker vmap( 2383*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None) 2384*da0073e9SAndroid Build Coastguard Worker ), 2385*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4), 2386*da0073e9SAndroid Build Coastguard Worker in_dims=(2, None, None, None), 2387*da0073e9SAndroid Build Coastguard Worker ) 2388*da0073e9SAndroid Build Coastguard Worker 2389*da0073e9SAndroid Build Coastguard Worker def test_unbind(self): 2390*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 2391*da0073e9SAndroid Build Coastguard Worker op = torch.unbind 2392*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 2393*da0073e9SAndroid Build Coastguard Worker 
2394*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None)) 2395*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2, 0),)) 2396*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None)) 2397*da0073e9SAndroid Build Coastguard Worker test( 2398*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(0, None)), 2399*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 1023, B0, 5), 1), 2400*da0073e9SAndroid Build Coastguard Worker in_dims=(2, None), 2401*da0073e9SAndroid Build Coastguard Worker ) 2402*da0073e9SAndroid Build Coastguard Worker test( 2403*da0073e9SAndroid Build Coastguard Worker vmap(vmap(lambda t: op(t, dim=1), in_dims=2)), 2404*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, 2, B0, 32, B2),), 2405*da0073e9SAndroid Build Coastguard Worker in_dims=2, 2406*da0073e9SAndroid Build Coastguard Worker ) 2407*da0073e9SAndroid Build Coastguard Worker 2408*da0073e9SAndroid Build Coastguard Worker def test_view(self): 2409*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 2410*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 2411*da0073e9SAndroid Build Coastguard Worker op = torch.Tensor.view 2412*da0073e9SAndroid Build Coastguard Worker 2413*da0073e9SAndroid Build Coastguard Worker # We should error out if the view would produce an incorrect result 2414*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 2415*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10]) 2416*da0073e9SAndroid Build Coastguard Worker 2417*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None)) 2418*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None)) 2419*da0073e9SAndroid Build Coastguard Worker test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),)) 
2420*da0073e9SAndroid Build Coastguard Worker test( 2421*da0073e9SAndroid Build Coastguard Worker vmap(vmap(lambda t: t.reshape([-1])), in_dims=1), 2422*da0073e9SAndroid Build Coastguard Worker (torch.rand(B2, B0, B1, 3, 2, 5),), 2423*da0073e9SAndroid Build Coastguard Worker in_dims=1, 2424*da0073e9SAndroid Build Coastguard Worker ) 2425*da0073e9SAndroid Build Coastguard Worker 2426*da0073e9SAndroid Build Coastguard Worker def test_view_as(self): 2427*da0073e9SAndroid Build Coastguard Worker test = self._vmap_view_test 2428*da0073e9SAndroid Build Coastguard Worker B0, B1, B2 = 7, 11, 13 2429*da0073e9SAndroid Build Coastguard Worker op = torch.Tensor.view_as 2430*da0073e9SAndroid Build Coastguard Worker 2431*da0073e9SAndroid Build Coastguard Worker # We should error out if the view would produce an incorrect result 2432*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 2433*da0073e9SAndroid Build Coastguard Worker vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10)) 2434*da0073e9SAndroid Build Coastguard Worker 2435*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5))) 2436*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0)) 2437*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None)) 2438*da0073e9SAndroid Build Coastguard Worker 2439*da0073e9SAndroid Build Coastguard Worker test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None)) 2440*da0073e9SAndroid Build Coastguard Worker 2441*da0073e9SAndroid Build Coastguard Worker test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10))) 2442*da0073e9SAndroid Build Coastguard Worker test( 2443*da0073e9SAndroid Build Coastguard Worker vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)), 2444*da0073e9SAndroid Build Coastguard Worker (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)), 
2445*da0073e9SAndroid Build Coastguard Worker in_dims=(2, 0), 2446*da0073e9SAndroid Build Coastguard Worker ) 2447*da0073e9SAndroid Build Coastguard Worker 2448*da0073e9SAndroid Build Coastguard Worker def test_no_random_op_support(self): 2449*da0073e9SAndroid Build Coastguard Worker B0 = 2 2450*da0073e9SAndroid Build Coastguard Worker 2451*da0073e9SAndroid Build Coastguard Worker captured = torch.rand(3) 2452*da0073e9SAndroid Build Coastguard Worker 2453*da0073e9SAndroid Build Coastguard Worker random_ops = [ 2454*da0073e9SAndroid Build Coastguard Worker # out-of-place on BatchedTensor 2455*da0073e9SAndroid Build Coastguard Worker (torch.bernoulli, (torch.rand(B0, 1),)), 2456*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.bernoulli(t, p=0.5), (torch.rand(B0, 1),)), 2457*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.multinomial(t, 2), (torch.rand(B0, 3),)), 2458*da0073e9SAndroid Build Coastguard Worker (torch.normal, (torch.randn(B0, 1), torch.randn(B0, 1))), 2459*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.normal(t, 1.0), (torch.randn(B0, 1),)), 2460*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.normal(0.0, t), (torch.randn(B0, 1),)), 2461*da0073e9SAndroid Build Coastguard Worker (torch.poisson, (torch.rand(B0, 1),)), 2462*da0073e9SAndroid Build Coastguard Worker (torch.rand_like, (torch.rand(B0, 1),)), 2463*da0073e9SAndroid Build Coastguard Worker (torch.randn_like, (torch.rand(B0, 1),)), 2464*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.randint_like(t, 2), (torch.rand(B0, 1),)), 2465*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.randint_like(t, 0, 2), (torch.rand(B0, 1),)), 2466*da0073e9SAndroid Build Coastguard Worker # out-of-place on captured tensor 2467*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.bernoulli(captured), (torch.rand(B0),)), 2468*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.bernoulli(captured, p=0.5), (torch.rand(B0),)), 2469*da0073e9SAndroid 
Build Coastguard Worker (lambda t: torch.multinomial(captured, 2), (torch.rand(B0),)), 2470*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.normal(captured, captured), (torch.randn(B0),)), 2471*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.normal(captured, 1.0), (torch.randn(B0),)), 2472*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.normal(0.0, captured), (torch.randn(B0),)), 2473*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.poisson(captured), (torch.rand(B0),)), 2474*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.rand_like(captured), (torch.rand(B0),)), 2475*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.randn_like(captured), (torch.rand(B0),)), 2476*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.randint_like(captured, 2), (torch.rand(B0),)), 2477*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.randint_like(captured, 0, 2), (torch.rand(B0),)), 2478*da0073e9SAndroid Build Coastguard Worker # in-place on BatchedTensor 2479*da0073e9SAndroid Build Coastguard Worker (lambda t: t.bernoulli_(), (torch.randn(B0, 1),)), 2480*da0073e9SAndroid Build Coastguard Worker (lambda t: t.cauchy_(), (torch.randn(B0, 1),)), 2481*da0073e9SAndroid Build Coastguard Worker (lambda t: t.exponential_(), (torch.randn(B0, 1),)), 2482*da0073e9SAndroid Build Coastguard Worker (lambda t: t.geometric_(0.5), (torch.randn(B0, 1),)), 2483*da0073e9SAndroid Build Coastguard Worker (lambda t: t.log_normal_(), (torch.randn(B0, 1),)), 2484*da0073e9SAndroid Build Coastguard Worker (lambda t: t.normal_(), (torch.randn(B0, 1),)), 2485*da0073e9SAndroid Build Coastguard Worker (lambda t: t.random_(), (torch.randn(B0, 1),)), 2486*da0073e9SAndroid Build Coastguard Worker (lambda t: t.random_(0, 2), (torch.randn(B0, 1),)), 2487*da0073e9SAndroid Build Coastguard Worker (lambda t: t.random_(2), (torch.randn(B0, 1),)), 2488*da0073e9SAndroid Build Coastguard Worker (lambda t: t.uniform_(), (torch.randn(B0, 1),)), 
2489*da0073e9SAndroid Build Coastguard Worker # in-place on captured tensor 2490*da0073e9SAndroid Build Coastguard Worker (lambda t: captured.bernoulli_(), (torch.randn(B0),)), 2491*da0073e9SAndroid Build Coastguard Worker (lambda t: captured.cauchy_(), (torch.randn(B0),)), 2492*da0073e9SAndroid Build Coastguard Worker (lambda t: captured.exponential_(), (torch.randn(B0),)), 2493*da0073e9SAndroid Build Coastguard Worker (lambda t: captured.geometric_(0.5), (torch.randn(B0),)), 2494*da0073e9SAndroid Build Coastguard Worker (lambda t: captured.log_normal_(), (torch.randn(B0),)), 2495*da0073e9SAndroid Build Coastguard Worker (lambda t: captured.normal_(), (torch.randn(B0),)), 2496*da0073e9SAndroid Build Coastguard Worker (lambda t: captured.random_(), (torch.randn(B0),)), 2497*da0073e9SAndroid Build Coastguard Worker (lambda t: captured.random_(0, 2), (torch.randn(B0),)), 2498*da0073e9SAndroid Build Coastguard Worker (lambda t: captured.random_(2), (torch.randn(B0),)), 2499*da0073e9SAndroid Build Coastguard Worker (lambda t: captured.uniform_(), (torch.randn(B0),)), 2500*da0073e9SAndroid Build Coastguard Worker # factory functions 2501*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.rand(1), (torch.randn(B0),)), 2502*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.randn(1), (torch.randn(B0),)), 2503*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.randint(5, [1]), (torch.randn(B0),)), 2504*da0073e9SAndroid Build Coastguard Worker (lambda t: torch.randperm(5), (torch.randn(B0),)), 2505*da0073e9SAndroid Build Coastguard Worker ] 2506*da0073e9SAndroid Build Coastguard Worker for op, args in random_ops: 2507*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 2508*da0073e9SAndroid Build Coastguard Worker RuntimeError, "vmap: We do not yet support calling random operations" 2509*da0073e9SAndroid Build Coastguard Worker ): 2510*da0073e9SAndroid Build Coastguard Worker vmap(op)(*args) 2511*da0073e9SAndroid Build Coastguard 
Worker 2512*da0073e9SAndroid Build Coastguard Worker 2513*da0073e9SAndroid Build Coastguard Workerdef construct_v(output, batch_size): 2514*da0073e9SAndroid Build Coastguard Worker return torch.randn( 2515*da0073e9SAndroid Build Coastguard Worker batch_size, *output.shape, dtype=output.dtype, device=output.device 2516*da0073e9SAndroid Build Coastguard Worker ) 2517*da0073e9SAndroid Build Coastguard Worker 2518*da0073e9SAndroid Build Coastguard Worker 2519*da0073e9SAndroid Build Coastguard Workerdef as_tuple(x): 2520*da0073e9SAndroid Build Coastguard Worker if isinstance(x, tuple): 2521*da0073e9SAndroid Build Coastguard Worker return x 2522*da0073e9SAndroid Build Coastguard Worker elif isinstance(x, list): 2523*da0073e9SAndroid Build Coastguard Worker return tuple(x) 2524*da0073e9SAndroid Build Coastguard Worker else: 2525*da0073e9SAndroid Build Coastguard Worker return (x,) 2526*da0073e9SAndroid Build Coastguard Worker 2527*da0073e9SAndroid Build Coastguard Worker 2528*da0073e9SAndroid Build Coastguard Workerdef differentiable(args): 2529*da0073e9SAndroid Build Coastguard Worker return tuple( 2530*da0073e9SAndroid Build Coastguard Worker arg 2531*da0073e9SAndroid Build Coastguard Worker for arg in as_tuple(args) 2532*da0073e9SAndroid Build Coastguard Worker if isinstance(arg, torch.Tensor) and arg.requires_grad 2533*da0073e9SAndroid Build Coastguard Worker ) 2534*da0073e9SAndroid Build Coastguard Worker 2535*da0073e9SAndroid Build Coastguard Worker 2536*da0073e9SAndroid Build Coastguard Workerdef _get_rand_no_zeros(*args, **kwargs): 2537*da0073e9SAndroid Build Coastguard Worker requires_grad = kwargs.get("requires_grad", False) 2538*da0073e9SAndroid Build Coastguard Worker kwargs_without_requires_grad = kwargs.copy() 2539*da0073e9SAndroid Build Coastguard Worker kwargs_without_requires_grad["requires_grad"] = False 2540*da0073e9SAndroid Build Coastguard Worker result = torch.rand(*args, **kwargs_without_requires_grad) 2541*da0073e9SAndroid Build Coastguard Worker 
class TestVmapBatchedGradientLegacy(Namespace.TestVmapBaseLegacy):
    def _vmap_test(self, *args, **kwargs):
        # Delegate to the module-level helper, passing this TestCase through.
        return _vmap_test(self, *args, **kwargs)

    # Tests batched gradient computation of outputs = op(*args, **kwargs)
    # by comparing it to a sequential map+stack fallback.
    #
    # output_process_fn: a function that maps the outputs to the part
    # that should be differentiated.
    # batch_size: the batch dim size for the batched grad
    def _batched_grad_test(
        self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3
    ):
        if kwargs is None:
            kwargs = {}
        outputs = op(*args, **kwargs)
        outputs = differentiable(output_process_fn(outputs))
        # One random batched cotangent per differentiable output.
        batched_vectors = tuple(construct_v(out, batch_size) for out in outputs)

        def vector_jacobian_product(*cotangents):
            # retain_graph: this VJP runs once per batch element in the
            # reference (map+stack) path, so the graph must survive reuse.
            return torch.autograd.grad(
                outputs, differentiable(args), cotangents, retain_graph=True
            )

        self._vmap_test(
            vector_jacobian_product, batched_vectors, check_propagates_grad=False
        )

    # Tests batched second grad computation of outputs = op(*args, **kwargs).
    # by comparing it to a sequential map+stack fallback.
    #
    # output_process_fn: a function that maps the outputs to the part
    # that should be differentiated.
    # batch_size: the batch dim size for the batched grad
    #
    # NB: we only test computing batched gradients in the second gradient
    # computation. One specific use case that does this is computing the hessian
    # matrix of a scalar-valued function; this is useful in Bayesian Logistic
    # Regression.
    # It might be useful to have a test that computes batched first gradients and
    # then uses those to compute batched second gradients in the future.
2585*da0073e9SAndroid Build Coastguard Worker def _batched_grad_grad_test( 2586*da0073e9SAndroid Build Coastguard Worker self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3 2587*da0073e9SAndroid Build Coastguard Worker ): 2588*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 2589*da0073e9SAndroid Build Coastguard Worker kwargs = {} 2590*da0073e9SAndroid Build Coastguard Worker outputs = op(*args, **kwargs) 2591*da0073e9SAndroid Build Coastguard Worker outputs = differentiable(output_process_fn(outputs)) 2592*da0073e9SAndroid Build Coastguard Worker ones = tuple(torch.ones_like(out) for out in outputs) 2593*da0073e9SAndroid Build Coastguard Worker # Same thing as summing together all of the outputs and calling .backward() 2594*da0073e9SAndroid Build Coastguard Worker first_grads = torch.autograd.grad( 2595*da0073e9SAndroid Build Coastguard Worker outputs, differentiable(args), ones, create_graph=True 2596*da0073e9SAndroid Build Coastguard Worker ) 2597*da0073e9SAndroid Build Coastguard Worker first_grads = differentiable(first_grads) 2598*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual( 2599*da0073e9SAndroid Build Coastguard Worker len(first_grads), 0, "None of the first grads depend on the input!" 
2600*da0073e9SAndroid Build Coastguard Worker ) 2601*da0073e9SAndroid Build Coastguard Worker 2602*da0073e9SAndroid Build Coastguard Worker batched_vectors = tuple(construct_v(grad, batch_size) for grad in first_grads) 2603*da0073e9SAndroid Build Coastguard Worker 2604*da0073e9SAndroid Build Coastguard Worker def vector_hessian_product(*vectors): 2605*da0073e9SAndroid Build Coastguard Worker outputs = torch.autograd.grad( 2606*da0073e9SAndroid Build Coastguard Worker first_grads, 2607*da0073e9SAndroid Build Coastguard Worker differentiable(args), 2608*da0073e9SAndroid Build Coastguard Worker vectors, 2609*da0073e9SAndroid Build Coastguard Worker retain_graph=True, 2610*da0073e9SAndroid Build Coastguard Worker allow_unused=True, 2611*da0073e9SAndroid Build Coastguard Worker ) 2612*da0073e9SAndroid Build Coastguard Worker outputs = tuple(out for out in outputs if out is not None) 2613*da0073e9SAndroid Build Coastguard Worker assert len(outputs) > 0 2614*da0073e9SAndroid Build Coastguard Worker return outputs 2615*da0073e9SAndroid Build Coastguard Worker 2616*da0073e9SAndroid Build Coastguard Worker self._vmap_test( 2617*da0073e9SAndroid Build Coastguard Worker vector_hessian_product, batched_vectors, check_propagates_grad=False 2618*da0073e9SAndroid Build Coastguard Worker ) 2619*da0073e9SAndroid Build Coastguard Worker 2620*da0073e9SAndroid Build Coastguard Worker def _test_arithmetic(self, op, device, test_grad_grad=True): 2621*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True, device=device) 2622*da0073e9SAndroid Build Coastguard Worker y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 2623*da0073e9SAndroid Build Coastguard Worker scalar = 3.14 2624*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(op, (x, y)) 2625*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(op, (scalar, y)) 2626*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(op, (x, scalar)) 2627*da0073e9SAndroid 
Build Coastguard Worker 2628*da0073e9SAndroid Build Coastguard Worker if test_grad_grad: 2629*da0073e9SAndroid Build Coastguard Worker self._batched_grad_grad_test(op, (x, y)) 2630*da0073e9SAndroid Build Coastguard Worker 2631*da0073e9SAndroid Build Coastguard Worker def test_add(self, device): 2632*da0073e9SAndroid Build Coastguard Worker self._test_arithmetic(torch.add, device, test_grad_grad=False) 2633*da0073e9SAndroid Build Coastguard Worker self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False) 2634*da0073e9SAndroid Build Coastguard Worker 2635*da0073e9SAndroid Build Coastguard Worker def test_sub(self, device): 2636*da0073e9SAndroid Build Coastguard Worker self._test_arithmetic(torch.sub, device, test_grad_grad=False) 2637*da0073e9SAndroid Build Coastguard Worker self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False) 2638*da0073e9SAndroid Build Coastguard Worker 2639*da0073e9SAndroid Build Coastguard Worker def test_mul(self, device): 2640*da0073e9SAndroid Build Coastguard Worker self._test_arithmetic(torch.mul, device) 2641*da0073e9SAndroid Build Coastguard Worker self._test_arithmetic(lambda x, y: x * y, device) 2642*da0073e9SAndroid Build Coastguard Worker 2643*da0073e9SAndroid Build Coastguard Worker def test_div(self, device): 2644*da0073e9SAndroid Build Coastguard Worker self._test_arithmetic(torch.div, device) 2645*da0073e9SAndroid Build Coastguard Worker self._test_arithmetic(lambda x, y: x / y, device) 2646*da0073e9SAndroid Build Coastguard Worker 2647*da0073e9SAndroid Build Coastguard Worker @allowVmapFallbackUsage 2648*da0073e9SAndroid Build Coastguard Worker def test_binary_cross_entropy(self, device): 2649*da0073e9SAndroid Build Coastguard Worker x = torch.sigmoid(torch.randn(3, 2, device=device, requires_grad=True)) 2650*da0073e9SAndroid Build Coastguard Worker target = torch.rand(3, 2, device=device) 2651*da0073e9SAndroid Build Coastguard Worker 2652*da0073e9SAndroid Build Coastguard Worker op = 
functools.partial(F.binary_cross_entropy, target=target) 2653*da0073e9SAndroid Build Coastguard Worker 2654*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(op, (x,), {}) 2655*da0073e9SAndroid Build Coastguard Worker self._batched_grad_grad_test(op, (x,), {}) 2656*da0073e9SAndroid Build Coastguard Worker 2657*da0073e9SAndroid Build Coastguard Worker def test_expand(self, device): 2658*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, device=device, requires_grad=True) 2659*da0073e9SAndroid Build Coastguard Worker 2660*da0073e9SAndroid Build Coastguard Worker def op(x): 2661*da0073e9SAndroid Build Coastguard Worker return x.expand(5, 5, 2, 3) 2662*da0073e9SAndroid Build Coastguard Worker 2663*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(op, (x,)) 2664*da0073e9SAndroid Build Coastguard Worker 2665*da0073e9SAndroid Build Coastguard Worker @allowVmapFallbackUsage 2666*da0073e9SAndroid Build Coastguard Worker def test_index(self, device): 2667*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True, device=device) 2668*da0073e9SAndroid Build Coastguard Worker index = torch.tensor([[0, 0], [1, 1]], device=device) 2669*da0073e9SAndroid Build Coastguard Worker 2670*da0073e9SAndroid Build Coastguard Worker def op(x): 2671*da0073e9SAndroid Build Coastguard Worker y = x * x 2672*da0073e9SAndroid Build Coastguard Worker return y[index] 2673*da0073e9SAndroid Build Coastguard Worker 2674*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(op, (x,)) 2675*da0073e9SAndroid Build Coastguard Worker self._batched_grad_grad_test(op, (x,)) 2676*da0073e9SAndroid Build Coastguard Worker 2677*da0073e9SAndroid Build Coastguard Worker def test_lgamma(self, device): 2678*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True, device=device) 2679*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(Tensor.lgamma, (x,)) 2680*da0073e9SAndroid Build Coastguard Worker 
self._batched_grad_grad_test(Tensor.lgamma, (x,)) 2681*da0073e9SAndroid Build Coastguard Worker 2682*da0073e9SAndroid Build Coastguard Worker def test_log(self, device): 2683*da0073e9SAndroid Build Coastguard Worker x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 2684*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(torch.log, (x,)) 2685*da0073e9SAndroid Build Coastguard Worker self._batched_grad_grad_test(torch.log, (x,)) 2686*da0073e9SAndroid Build Coastguard Worker 2687*da0073e9SAndroid Build Coastguard Worker def test_logsumexp(self, device): 2688*da0073e9SAndroid Build Coastguard Worker x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 2689*da0073e9SAndroid Build Coastguard Worker 2690*da0073e9SAndroid Build Coastguard Worker def op(x): 2691*da0073e9SAndroid Build Coastguard Worker return torch.logsumexp(x, -1) 2692*da0073e9SAndroid Build Coastguard Worker 2693*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(op, (x,)) 2694*da0073e9SAndroid Build Coastguard Worker self._batched_grad_grad_test(op, (x,)) 2695*da0073e9SAndroid Build Coastguard Worker 2696*da0073e9SAndroid Build Coastguard Worker def test_log1p(self, device): 2697*da0073e9SAndroid Build Coastguard Worker x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True) 2698*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(torch.log1p, (x,)) 2699*da0073e9SAndroid Build Coastguard Worker self._batched_grad_grad_test(torch.log1p, (x,)) 2700*da0073e9SAndroid Build Coastguard Worker 2701*da0073e9SAndroid Build Coastguard Worker @allowVmapFallbackUsage 2702*da0073e9SAndroid Build Coastguard Worker def test_max(self, device): 2703*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True, device=device) 2704*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(torch.max, (x,)) 2705*da0073e9SAndroid Build Coastguard Worker 2706*da0073e9SAndroid Build Coastguard Worker @allowVmapFallbackUsage 
2707*da0073e9SAndroid Build Coastguard Worker def test_median(self, device): 2708*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True, device=device) 2709*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(torch.median, (x,)) 2710*da0073e9SAndroid Build Coastguard Worker 2711*da0073e9SAndroid Build Coastguard Worker @allowVmapFallbackUsage 2712*da0073e9SAndroid Build Coastguard Worker def test_min(self, device): 2713*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True, device=device) 2714*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(torch.min, (x,)) 2715*da0073e9SAndroid Build Coastguard Worker 2716*da0073e9SAndroid Build Coastguard Worker def test_permute(self, device): 2717*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, 5, requires_grad=True, device=device) 2718*da0073e9SAndroid Build Coastguard Worker 2719*da0073e9SAndroid Build Coastguard Worker def op(x): 2720*da0073e9SAndroid Build Coastguard Worker return x.permute(2, 0, 1) 2721*da0073e9SAndroid Build Coastguard Worker 2722*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(op, (x,)) 2723*da0073e9SAndroid Build Coastguard Worker 2724*da0073e9SAndroid Build Coastguard Worker def test_reshape(self, device): 2725*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, 5, requires_grad=True, device=device) 2726*da0073e9SAndroid Build Coastguard Worker 2727*da0073e9SAndroid Build Coastguard Worker def op(x): 2728*da0073e9SAndroid Build Coastguard Worker return x.reshape([2 * 3, 5]) 2729*da0073e9SAndroid Build Coastguard Worker 2730*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(op, (x,)) 2731*da0073e9SAndroid Build Coastguard Worker 2732*da0073e9SAndroid Build Coastguard Worker def test_sigmoid(self, device): 2733*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, requires_grad=True, device=device) 2734*da0073e9SAndroid Build Coastguard Worker 
self._batched_grad_test(Tensor.sigmoid, (x,)) 2735*da0073e9SAndroid Build Coastguard Worker self._batched_grad_grad_test(Tensor.sigmoid, (x,)) 2736*da0073e9SAndroid Build Coastguard Worker 2737*da0073e9SAndroid Build Coastguard Worker def test_stack(self, device): 2738*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, device=device, requires_grad=True) 2739*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, 3, device=device, requires_grad=True) 2740*da0073e9SAndroid Build Coastguard Worker 2741*da0073e9SAndroid Build Coastguard Worker def op(x, y): 2742*da0073e9SAndroid Build Coastguard Worker return torch.stack([x, y]) 2743*da0073e9SAndroid Build Coastguard Worker 2744*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(op, (x, y)) 2745*da0073e9SAndroid Build Coastguard Worker 2746*da0073e9SAndroid Build Coastguard Worker def test_select(self, device): 2747*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, device=device, requires_grad=True) 2748*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(lambda x: x[1], (x,)) 2749*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(lambda x: x.select(1, 2), (x,)) 2750*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(lambda x: x.select(-1, 0), (x,)) 2751*da0073e9SAndroid Build Coastguard Worker 2752*da0073e9SAndroid Build Coastguard Worker def test_slice(self, device): 2753*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, 5, device=device, requires_grad=True) 2754*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(lambda x: x[0:1], (x,)) 2755*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(lambda x: x[:, 1:3], (x,)) 2756*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(lambda x: x[..., 1:3], (x,)) 2757*da0073e9SAndroid Build Coastguard Worker 2758*da0073e9SAndroid Build Coastguard Worker def test_trace(self, device): 2759*da0073e9SAndroid Build Coastguard Worker x = 
torch.randn(2, 3, device=device, requires_grad=True) 2760*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(Tensor.trace, (x,)) 2761*da0073e9SAndroid Build Coastguard Worker 2762*da0073e9SAndroid Build Coastguard Worker def test_threshold(self, device): 2763*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, device=device, requires_grad=True) 2764*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,)) 2765*da0073e9SAndroid Build Coastguard Worker 2766*da0073e9SAndroid Build Coastguard Worker @allowVmapFallbackUsage 2767*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view(self, device): 2768*da0073e9SAndroid Build Coastguard Worker leaf = torch.randn(4, 5, requires_grad=True) 2769*da0073e9SAndroid Build Coastguard Worker 2770*da0073e9SAndroid Build Coastguard Worker def func(leaf): 2771*da0073e9SAndroid Build Coastguard Worker # Make sure the function is non-trivially twice differentiable 2772*da0073e9SAndroid Build Coastguard Worker base = leaf * leaf 2773*da0073e9SAndroid Build Coastguard Worker view = base[0] 2774*da0073e9SAndroid Build Coastguard Worker view.cos_() 2775*da0073e9SAndroid Build Coastguard Worker return view 2776*da0073e9SAndroid Build Coastguard Worker 2777*da0073e9SAndroid Build Coastguard Worker self._batched_grad_test(func, (leaf,), {}) 2778*da0073e9SAndroid Build Coastguard Worker self._batched_grad_grad_test(func, (leaf,), {}) 2779*da0073e9SAndroid Build Coastguard Worker 2780*da0073e9SAndroid Build Coastguard Worker @allowVmapFallbackUsage 2781*da0073e9SAndroid Build Coastguard Worker def test_inplace_manyview(self, device): 2782*da0073e9SAndroid Build Coastguard Worker leaf = torch.randn(4, 4, 5, requires_grad=True) 2783*da0073e9SAndroid Build Coastguard Worker 2784*da0073e9SAndroid Build Coastguard Worker def func(leaf): 2785*da0073e9SAndroid Build Coastguard Worker # Make sure the function is non-trivially twice differentiable 
    def test_diagonal(self, device):
        # Batched gradients through Tensor.diagonal: positive offset on a 2-D
        # input, then negative-dim addressing on a 3-D input.
        x = torch.randn(4, 5, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,))

        x = torch.randn(3, 4, 5, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,))

    @allowVmapFallbackUsage
    def test_unrelated_output(self, device):
        # vmap over a vjp whose output does not depend on the input:
        # autograd.grad returns None (allow_unused=True) and the closure
        # substitutes zeros.
        B0 = 3
        # NOTE(review): x/y/gy are allocated on the default device rather than
        # `device`, while the expected tensor below uses `device`. assertEqual
        # tolerates the mismatch, but confirm whether this is intentional.
        x = torch.randn([], requires_grad=True)
        y = torch.randn([], requires_grad=True)
        gy = torch.randn(B0, requires_grad=True)

        def vjp(v):
            (res,) = torch.autograd.grad(y, x, v, allow_unused=True)
            return torch.zeros_like(x) if res is None else res

        result = vmap(vjp)(gy)
        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))

    @allowVmapFallbackUsage
    def test_unrelated_output_multiple_grad(self, device):
        # Same as test_unrelated_output, but vjp is first called once outside
        # vmap -- presumably to check that a prior un-vmapped grad call does
        # not interfere with the vmapped one; TODO confirm intent.
        B0 = 3
        x = torch.randn([], requires_grad=True)
        y = torch.randn([], requires_grad=True)
        gy = torch.randn(B0, requires_grad=True)

        def vjp(v):
            (res,) = torch.autograd.grad(y, x, v, allow_unused=True)
            return torch.zeros_like(x) if res is None else res

        _ = vjp(gy[0])
        result = vmap(vjp)(gy)
        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))


# Generate per-device (e.g. CPU/CUDA) variants of the batched-gradient tests.
instantiate_device_type_tests(TestVmapBatchedGradientLegacy, globals(), None)

if __name__ == "__main__":
    run_tests()