# Owner(s): ["module: vmap"]

import functools
import itertools
import types
import warnings

import torch
import torch.nn.functional as F
from torch import Tensor
from torch._vmap_internals import vmap
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase


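# Substring of the performance-drop warning emitted when an operator goes
# through the slow vmap fallback path; the fallback tests below match
# recorded warnings against this regex.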
FALLBACK_REGEX = r"There is a performance drop"


class EnableVmapFallbackWarnings:
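    """Context manager that turns on the debug-only vmap fallback warnings
    for its scope and restores the previous setting on exit."""
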
    def __enter__(self):
        self.prev_state = torch._C._debug_only_are_vmap_fallback_warnings_enabled()
        torch._C._debug_only_display_vmap_fallback_warnings(True)

    def __exit__(self, *ignored):
        torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state)


class TestVmapAPILegacy(TestCase):
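    """Tests for the legacy vmap prototype exposed as torch._vmap_internals.vmap."""
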
    def test_non_tensor_output_raises(self):
        with self.assertRaisesRegex(
            ValueError, "got type <class 'float'> as the return"
        ):
            output = vmap(lambda x: 3.14)(torch.ones(3))

        def multiple_outputs(x):
            return x, 3

        with self.assertRaisesRegex(ValueError, "got type <class 'int'> for return 1"):
            vmap(multiple_outputs)(torch.ones(3))

    def test_different_map_dim_size_raises(self):
        x = torch.randn(2)
        y = torch.randn(3)
        expected_msg = (
            "Expected all tensors to have the same size in the mapped dimension"
        )
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(torch.mul)(x, y)
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
                {"x": x, "y": y}
            )

    def test_func_with_no_inputs(self):
        expected_msg = "got no inputs"

        def foo():
            return torch.randn(3)

        def bar(x):
            return torch.randn(3)

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(foo)()

        with self.assertRaisesRegex(ValueError, expected_msg):
            vmap(bar)()

    def test_constant_function(self):
        output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
        self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))

    def test_single_input(self):
        x = torch.randn(2, 3)

        def square(x):
            return x * x

        output = vmap(square)(x)
        self.assertEqual(output, x * x)

    def test_multiple_inputs(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)
        output = vmap(torch.mul)(x, y)
        self.assertEqual(output, x * y)

    def test_multiple_outputs(self):
        def foo(x):
            return x * x, x * x * x

        x = torch.randn(3)
        outputs = vmap(foo)(x)
        self.assertEqual(outputs[0], x * x)
        self.assertEqual(outputs[1], x * x * x)

    def test_multiple_outputs_error_cases(self):
        # This is the same thing as
        # def returns_tuple_of_tensors(x):
        #     return x, x
        def returns_tuple_of_tensors(x):
            return (x, x)

        def returns_list_of_two_tensors(x):
            return [x, x]

        def returns_list_of_one_tensor(x):
            return [x]

        x = torch.randn(3)

        # should not throw
        vmap(returns_tuple_of_tensors)(x)

        # jax supports these, but we don't yet
        msg = "must only return Tensors, got type <class 'list'>"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_two_tensors)(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(returns_list_of_one_tensor)(x)

    def test_nested_with_same_map_dim(self):
        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        output = vmap(vmap(torch.mul))(x, y)
        self.assertEqual(output, x * y)

        output = vmap(vmap(vmap(torch.mul)))(x, y)
        self.assertEqual(output, x * y)

    def test_nested_with_different_map_dim(self):
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
        self.assertEqual(output.shape, (2, 5, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        z = torch.randn(7, 3)
        output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
        self.assertEqual(output.shape, (2, 5, 7, 3))
        self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)

    def test_noop_in_inner_vmap(self):
        x = torch.randn(3)
        y = torch.randn(5)
        output = vmap(lambda x: vmap(lambda y: x)(y))(x)
        self.assertEqual(output, x.view(3, 1).expand(3, 5))

    def test_unsupported_op_err_msg(self):
        # Unsupported view op
        tensor = torch.randn(2, 3)
        msg = (
            r"Batching rule not implemented for aten::.+; the "
            r"fallback path doesn't work on out= or view ops"
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(torch.ravel)(tensor)

        def out_op(x, y):
            return torch.abs(x, out=y)

        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(out_op)(tensor, tensor)

        tensor = torch.randn(2)
        # The fallback doesn't support TensorList
        with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
            vmap(lambda t: torch.atleast_1d([t]))(tensor)

        # Don't support non-tensor returns. This is a limitation of vmap;
        # functions that don't return tensors must be special cased
        with self.assertRaisesRegex(RuntimeError, "Batching rule not implemented"):
            vmap(torch.Tensor.item)(tensor)

    def test_nonzero_out_dims(self):
        # Basic test
        tensor = torch.randn(2, 3)
        result = vmap(lambda x: x, out_dims=1)(tensor)
        self.assertEqual(result, tensor.permute(1, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # Test that the batch dimension gets permuted to dim 2
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 0, 3))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # negative out_dim
        tensor = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x: x, out_dims=-1)(tensor)
        self.assertEqual(result, tensor.permute(1, 2, 3, 0))
        self.assertEqual(result.data_ptr(), tensor.data_ptr())

        # check that out_dims works on ALL outputs
        tensor = torch.randn(2, 3, 5, 7)
        other = torch.randn(2, 3, 5, 7)
        result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
        self.assertEqual(
            result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3))
        )

        # use out_dims with the maximum vmap-able tensor dims (64 dims)
        ndims = 64
        shape = [2] + [1] * (ndims - 1)
        expected_shape = [1, 1, 2] + [1] * (ndims - 3)
        tensor = torch.randn(shape)
        result = vmap(lambda x: x, out_dims=2)(tensor)
        self.assertEqual(result.shape, expected_shape)

        # test something that is not the identity function
        def foo(x, y):
            return x, x * y, x * y * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=1)(x, y)
        self.assertEqual(
            result,
            (
                x.permute(1, 0, 2),
                (x * y).permute(1, 0, 2),
                (x * y * y).permute(1, 0, 2),
            ),
        )

    def test_multiple_out_dims(self):
        def foo(x):
            return x, x

        def bar(x, y):
            return x, x, x, x * y

        x = torch.randn(2, 3, 5)
        y = torch.randn(2, 3, 5)
        result = vmap(foo, out_dims=(0, 1))(x)
        self.assertEqual(result, (x, x.permute(1, 0, 2)))

        result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
        expected = (
            x.permute(1, 2, 0),
            x,
            x.permute(1, 0, 2),
            (x * y).permute(1, 2, 0),
        )
        self.assertEqual(result, expected)

    def test_nested_out_dims(self):
        y = torch.randn(2, 3, 5, 7)

        # Inner vmap has non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
        self.assertEqual(result.shape, (2, 5, 3, 7))
        self.assertEqual(result, y.permute(0, 2, 1, 3))

        # all vmaps have non-zero out_dim
        result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
        self.assertEqual(result.shape, (5, 2, 3, 7))
        self.assertEqual(result, y.permute(2, 0, 1, 3))

        # throwing in some negative out_dims
        result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
        self.assertEqual(result.shape, (5, 7, 3, 2))
        self.assertEqual(result, y.permute(2, 3, 1, 0))

        # testing fn that isn't the identity
        x = torch.randn(2, 3)
        y = torch.randn(5, 3)
        result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
        self.assertEqual(result.shape, (3, 2, 5))
        self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))

    def test_out_dims_edge_case(self):
        def foo(x):
            return x

        # Test that we accept out_dims=(1,) for a function with one output.
        tensor = torch.randn(2, 3)
        expected = vmap(foo, out_dims=1)(tensor)
        result = vmap(foo, out_dims=(1,))(tensor)
        self.assertEqual(result, expected)

    def test_out_dims_must_be_int_or_tuple_of_int_err_msg(self):
        msg = "`out_dims` must be an int or a tuple of int"
        tensor = torch.randn(2, 3)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims="lol")(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=("lol",))(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=None)(tensor)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(None,))(tensor)

    def test_out_dims_and_num_outputs_mismatch_err_msg(self):
        msg = "`out_dims` must have one dim per output"
        x = torch.randn(2, 3, 5)

        # Too many out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: x, out_dims=(0, 0))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)

        # Too few out_dims
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x), out_dims=(0,))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)

    def test_out_dim_out_of_bounds_err_msg(self):
        # TODO(rzou): This error message isn't that great. It comes straight
        # from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
        # the error message in the future in C++
        msg = "Dimension out of range"
        x = torch.randn(2, 3, 5)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=3)(x)
        with self.assertRaisesRegex(IndexError, msg):
            vmap(lambda x: x, out_dims=-4)(x)

    def test_non_zero_in_dims(self):
        tensor = torch.randn(2, 3, 5)

        # Implicit out_dims = 0; vmap will move the batch dim to the front.
        output = vmap(lambda x: x, (1,))(tensor)
        self.assertEqual(output, tensor.permute(1, 0, 2))
        self.assertEqual(output.data_ptr(), tensor.data_ptr())

        x = torch.randn(2, 3)
        y = torch.randn(3, 2)
        output = vmap(torch.mul, (0, 1))(x, y)
        self.assertEqual(output, x * y.t())
        output = vmap(torch.mul, (1, 0))(x, y)
        self.assertEqual(output, x.t() * y)

    def test_none_in_dims(self):
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # None in_dim for a Tensor means we don't map over it
        output = vmap(torch.mul, (0, None))(x, y)
        self.assertEqual(output.shape, (2, 2, 3))
        self.assertEqual(output, x.view(2, 1, 3) * y)

        # None in_dim for non-tensor arguments
        output = vmap(torch.mul, (0, None))(x, 2)
        self.assertEqual(output, x * 2)

    def test_nested_non_default_in_dims(self):
        x = torch.rand(5, 2, 3)
        y = torch.rand(3, 5, 2)
        result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y)
        self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))

    def test_non_default_in_dims_out_dims(self):
        x = torch.randn(2, 3, 5)

        # Same in_dim as out_dim, vmap over identity
        result = vmap(lambda x: x, in_dims=1, out_dims=1)(x)
        self.assertEqual(result, x)
        self.assertEqual(result.data_ptr(), x.data_ptr())

        # Different in_dim from out_dim, vmap over identity
        result = vmap(lambda x: x, in_dims=2, out_dims=1)(x)
        self.assertEqual(result.shape, (2, 5, 3))
        self.assertEqual(result, x.transpose(1, 2))
        self.assertEqual(result.data_ptr(), x.data_ptr())

        def foo(x):
            return x * 2

        # Same in_dim as out_dim, vmap over operation
        result = vmap(foo, in_dims=1, out_dims=1)(x)
        self.assertEqual(result, x * 2)

        # Different in_dim from out_dim, vmap over operation
        result = vmap(foo, in_dims=2, out_dims=1)(x)
        self.assertEqual(result.shape, (2, 5, 3))
        self.assertEqual(result, (x * 2).transpose(1, 2))

        # Basic nested test.
        result = vmap(vmap(foo, 1, 1), 1, 1)(x)
        self.assertEqual(result, x * 2)

    def test_accepts_nested_inputs(self):
        B0 = 2
        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # Single layer of nesting
        out = vmap(lambda z: z[0] + z[1])((x, y))
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y))
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
        self.assertEqual(out, x + y)

        out = vmap(lambda z: z[0] + z[1])([x, y])
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y])
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y])
        self.assertEqual(out, x + y)

        out = vmap(lambda z: z["x"] + z["y"])({"x": x, "y": y})
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z["x"] + z["y"], in_dims=(0,))({"x": x, "y": y})
        self.assertEqual(out, x + y)
        out = vmap(lambda z: z["x"] + z["y"], in_dims=({"x": 0, "y": 0},))(
            {"x": x, "y": y}
        )
        self.assertEqual(out, x + y)

        # Multiple layers of nesting
        out_fn = vmap(lambda z: z["x"][0] + z["x"][1][0] + z["y"][0] + z["y"][1])
        out = out_fn({"x": [x, (x,)], "y": [y, y]})
        self.assertEqual(out, x + x + y + y)

    def test_in_dims_wrong_type_err_msg(self):
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r"expected `in_dims` to be int or a \(potentially nested\) tuple"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, [0, 0])(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, set({0}))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, "lol")(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y])
        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_not_enough_in_dims_err_msg(self):
        x = torch.randn(3)
        y = torch.randn(3)
        msg = r"in_dims is not compatible with the structure of `inputs`"

        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, (0,))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.mul, (0, 0, 0))(x, y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y])
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y])
        # The following should not throw
        vmap(torch.mul, (0, 0))(x, y)

    def test_integer_in_dim_but_not_tensor_input_err_msg(self):
        def foo(xy):
            return xy[0] * xy[1]

        def bar(x, yz):
            return x * yz[0] * yz[1]

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        # the following are errors in jax (and will always be errors)
        msg = "Got in_dim=0 for an input but the input is of type"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum)(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(torch.sum, (0, 0))(x, 0)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1])
        # The following should not throw
        vmap(torch.sum, (0, None))(x, 0)

    def test_in_dim_not_in_tensor_err_msg(self):
        def foo(x):
            return x * x

        x = torch.randn(2, 3)
        y = torch.randn(2, 3)

        msg = r"Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w"
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo)(torch.randn([]))
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(0,))(torch.randn([]))
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(-1,))(x)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(foo, in_dims=(2,))(y)
        with self.assertRaisesRegex(ValueError, msg):
            vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y])
        # the following should not throw
        vmap(foo, in_dims=(0,))(torch.randn(2, 3))
        vmap(foo, in_dims=(1,))(torch.randn(2, 3))

    def test_fallback_does_not_warn_by_default(self):
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch.atan2
        x = torch.randn(11)
        y = torch.randn(11)
        with warnings.catch_warnings(record=True) as wa:
            result = vmap(op)(x, y)
            # The single warning here is the "vmap is experimental"
            # warning, not a warning from the vmap fallback path.
            self.assertEqual(len(wa), 1)

    def test_fallback_warns_when_warnings_are_enabled(self):
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch.atan2
        x = torch.randn(11)
        y = torch.randn(11)
        with warnings.catch_warnings(record=True) as wa:
            with EnableVmapFallbackWarnings():
                result = vmap(op)(x, y)
            self.assertEqual(len(wa), 2)
            self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)

    def _assert_uses_vmap_fallback(self, vmap_args, inputs):
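        """Assert that vmap(*vmap_args)(*inputs) hits the slow fallback path:
        with fallback warnings enabled, the call should emit a warning matching
        FALLBACK_REGEX in addition to the "vmap is experimental" warning."""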
        with warnings.catch_warnings(record=True) as wa:
            with EnableVmapFallbackWarnings():
                result = vmap(*vmap_args)(*inputs)
            self.assertEqual(len(wa), 2)
            self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)

    def test_fallback_zero_dim(self):
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch.atan2
        x = torch.randn(11)
        y = torch.randn(11)
        self._assert_uses_vmap_fallback((op,), (x, y))

        B0, B1 = 0, 3
        x = torch.randn(B0, 11)
        y = torch.randn(11)

        msg = "The fallback path does not support vmap over dims of size 0"

        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (0, None))(x, y)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (None, 0))(y, x)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(x, x)

        x = torch.randn(B0, B1, 11)
        y = torch.randn(B1, 11)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (0, None))(x, y)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op, (None, 0))(y, x)
        with self.assertRaisesRegex(RuntimeError, msg):
            vmap(op)(x, x)

    def test_fallback_atan2(self):
        # NB: One day we will implement a batching rule for torch.atan2.
        # If/when we do, this test should be replaced to test the fallback
        # path on another operator to avoid bitrot.
        op = torch.atan2

        x = torch.randn(5, 7, 11)
        y = torch.randn(5, 7, 11)

        self._assert_uses_vmap_fallback((op,), (x, y))

        # fallback on torch.atan2
        x = torch.randn(7, 11, 5)
        y = torch.randn(5, 7, 11)
        result = vmap(op, (2, 0))(x, y)
        self.assertEqual(result, op(x.permute(2, 0, 1), y))

        # fallback on torch.atan2, nested vmap
        x = torch.randn(7, 11, 5)
        y = torch.randn(5, 7, 11)
        result = vmap(vmap(op), (2, 0))(x, y)
        self.assertEqual(result, op(x.permute(2, 0, 1), y))

        # big batch size (total 10000)
        x = torch.randn(100, 10, 10, 5)
        y = torch.randn(100, 10, 10)
        result = vmap(vmap(vmap(op)))(x, y)
        self.assertEqual(result, op(x, y.view(100, 10, 10, 1)))

    def test_fallback_masked_fill(self):
        # NB: One day we will implement a batching rule for index_add, which
        # is the operator exercised here. If/when we do, this test should be
        # replaced to test the fallback path on another operator to avoid bitrot.
592*da0073e9SAndroid Build Coastguard Worker        def run_test(batch_size):
593*da0073e9SAndroid Build Coastguard Worker            B0 = batch_size
594*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B0, 7, 11, 13)
595*da0073e9SAndroid Build Coastguard Worker            dim = 0
596*da0073e9SAndroid Build Coastguard Worker            index = torch.tensor([0, 4, 2])
597*da0073e9SAndroid Build Coastguard Worker            values = torch.randn(B0, 3, 11, 13)
598*da0073e9SAndroid Build Coastguard Worker
599*da0073e9SAndroid Build Coastguard Worker            self._assert_uses_vmap_fallback(
600*da0073e9SAndroid Build Coastguard Worker                (torch.index_add, (0, None, None, 0)), (x, dim, index, values)
601*da0073e9SAndroid Build Coastguard Worker            )
602*da0073e9SAndroid Build Coastguard Worker
603*da0073e9SAndroid Build Coastguard Worker            result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
604*da0073e9SAndroid Build Coastguard Worker            expected = torch.index_add(x, dim + 1, index, values.view(B0, 3, 11, 13))
605*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, expected)
606*da0073e9SAndroid Build Coastguard Worker
607*da0073e9SAndroid Build Coastguard Worker        run_test(batch_size=5)
608*da0073e9SAndroid Build Coastguard Worker        run_test(batch_size=1237)
609*da0073e9SAndroid Build Coastguard Worker
610*da0073e9SAndroid Build Coastguard Worker    def test_fallback_multiple_returns(self):
611*da0073e9SAndroid Build Coastguard Worker        # NB: One day we will implement a batching rule for torch.var_mean
612*da0073e9SAndroid Build Coastguard Worker        # If/when we do, this test should be replaced to test the fallback
613*da0073e9SAndroid Build Coastguard Worker        # path on another operator to avoid bitrot.
614*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 2, 3, 1237
615*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(B0, 10)
616*da0073e9SAndroid Build Coastguard Worker
617*da0073e9SAndroid Build Coastguard Worker        self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker        # fallback correctness on torch.var_mean
620*da0073e9SAndroid Build Coastguard Worker        result = vmap(torch.var_mean)(tensor)
621*da0073e9SAndroid Build Coastguard Worker        expected = torch.var_mean(tensor, dim=1)
622*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
623*da0073e9SAndroid Build Coastguard Worker
624*da0073e9SAndroid Build Coastguard Worker        # nested vmap
625*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(B0, B1, 10)
626*da0073e9SAndroid Build Coastguard Worker        result = vmap(vmap(torch.var_mean))(tensor)
627*da0073e9SAndroid Build Coastguard Worker        expected = torch.var_mean(tensor, dim=2)
628*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
629*da0073e9SAndroid Build Coastguard Worker
630*da0073e9SAndroid Build Coastguard Worker        # big batch size, nested vmap
631*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(B0, B1, B2, 10)
632*da0073e9SAndroid Build Coastguard Worker        result = vmap(vmap(vmap(torch.var_mean)))(tensor)
633*da0073e9SAndroid Build Coastguard Worker        expected = torch.var_mean(tensor, dim=3)
634*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
635*da0073e9SAndroid Build Coastguard Worker
636*da0073e9SAndroid Build Coastguard Worker    def test_inplace_fallback_unary(self):
637*da0073e9SAndroid Build Coastguard Worker        # Test the in-place fallback on an in-place method that takes no
638*da0073e9SAndroid Build Coastguard Worker        # additional Tensor arguments. This is the simplest case of the fallback.
639*da0073e9SAndroid Build Coastguard Worker        # NB: One day we will implement a batching rule for acos_.
640*da0073e9SAndroid Build Coastguard Worker        # If/when we do, this test should be replaced to test the fallback
641*da0073e9SAndroid Build Coastguard Worker        # path on another operator to avoid bitrot.
642*da0073e9SAndroid Build Coastguard Worker        op = Tensor.acos_
643*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 2, 3, 10000
644*da0073e9SAndroid Build Coastguard Worker
645*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(B0, 5)
646*da0073e9SAndroid Build Coastguard Worker        self._assert_uses_vmap_fallback((op,), (x,))
647*da0073e9SAndroid Build Coastguard Worker
648*da0073e9SAndroid Build Coastguard Worker        # Single vmap
649*da0073e9SAndroid Build Coastguard Worker        x_orig = torch.rand(B0, 5)
650*da0073e9SAndroid Build Coastguard Worker        x = x_orig.clone()
651*da0073e9SAndroid Build Coastguard Worker        result = vmap(op)(x)
652*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result is x)
653*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, x_orig.acos())
654*da0073e9SAndroid Build Coastguard Worker
655*da0073e9SAndroid Build Coastguard Worker        # Single vmap + different out_dim produces a view(!)
656*da0073e9SAndroid Build Coastguard Worker        x_orig = torch.rand(B0, 5)
657*da0073e9SAndroid Build Coastguard Worker        x = x_orig.clone()
658*da0073e9SAndroid Build Coastguard Worker        result = vmap(op, out_dims=(1,))(x)
659*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result._base is x)
660*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, x_orig.t().acos())
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker        # Nested vmap
663*da0073e9SAndroid Build Coastguard Worker        x_orig = torch.randn(B0, B1, 5)
664*da0073e9SAndroid Build Coastguard Worker        x = x_orig.clone()
665*da0073e9SAndroid Build Coastguard Worker        result = vmap(vmap(op))(x)
666*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result is x)
667*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, x_orig.acos())
668*da0073e9SAndroid Build Coastguard Worker
669*da0073e9SAndroid Build Coastguard Worker        # Nested vmap, large batch size
670*da0073e9SAndroid Build Coastguard Worker        x_orig = torch.randn(B0, B1, B2, 5)
671*da0073e9SAndroid Build Coastguard Worker        x = x_orig.clone()
672*da0073e9SAndroid Build Coastguard Worker        result = vmap(vmap(vmap(op)))(x)
673*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result is x)
674*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, x_orig.acos())
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker    def test_inplace_fallback_nary_same_levels(self):
677*da0073e9SAndroid Build Coastguard Worker        # NB: One day we will implement a batching rule for atan2_
678*da0073e9SAndroid Build Coastguard Worker        # If/when we do, this test should be replaced to test the fallback
679*da0073e9SAndroid Build Coastguard Worker        # path on another operator to avoid bitrot.
680*da0073e9SAndroid Build Coastguard Worker        op = Tensor.atan2_
681*da0073e9SAndroid Build Coastguard Worker        outplace_op = torch.atan2
682*da0073e9SAndroid Build Coastguard Worker
683*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 7, 11)
684*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 7, 11)
685*da0073e9SAndroid Build Coastguard Worker        self._assert_uses_vmap_fallback((op,), (x, y))
686*da0073e9SAndroid Build Coastguard Worker
687*da0073e9SAndroid Build Coastguard Worker        # Single vmap
688*da0073e9SAndroid Build Coastguard Worker        B0 = 5
689*da0073e9SAndroid Build Coastguard Worker        x_orig = torch.randn(7, 11, B0)
690*da0073e9SAndroid Build Coastguard Worker        x = x_orig.clone()
691*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(B0, 7, 11)
692*da0073e9SAndroid Build Coastguard Worker        vmap(op, (2, 0))(x, y)
693*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2)))
694*da0073e9SAndroid Build Coastguard Worker
695*da0073e9SAndroid Build Coastguard Worker        # Nested vmap
696*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 5, 7
697*da0073e9SAndroid Build Coastguard Worker        x_orig = torch.randn(B1, 11, B0)
698*da0073e9SAndroid Build Coastguard Worker        x = x_orig.clone()
699*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(B0, B1, 11)
700*da0073e9SAndroid Build Coastguard Worker        vmap(vmap(op), (2, 0))(x, y)
701*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0])))
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker        # Big batch size (B0 * B1 * B2 = 10000)
704*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 100, 10, 10
705*da0073e9SAndroid Build Coastguard Worker        x_orig = torch.randn(B0, B1, B2, 5)
706*da0073e9SAndroid Build Coastguard Worker        x = x_orig.clone()
707*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(B0, B1, B2)
708*da0073e9SAndroid Build Coastguard Worker        vmap(vmap(vmap(op)))(x, y)
709*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1)))
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker    def test_inplace_fallback_nary_different_levels(self):
712*da0073e9SAndroid Build Coastguard Worker        # NB: One day we will implement a batching rule for atan2_
713*da0073e9SAndroid Build Coastguard Worker        # If/when we do, this test should be replaced to test the fallback
714*da0073e9SAndroid Build Coastguard Worker        # path on another operator to avoid bitrot.
715*da0073e9SAndroid Build Coastguard Worker        op = Tensor.atan2_
716*da0073e9SAndroid Build Coastguard Worker        outplace_op = torch.atan2
717*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 2, 3, 5
718*da0073e9SAndroid Build Coastguard Worker
719*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(B0, 7)
720*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(7)
721*da0073e9SAndroid Build Coastguard Worker        self._assert_uses_vmap_fallback((op, (0, None)), (x, y))
722*da0073e9SAndroid Build Coastguard Worker
723*da0073e9SAndroid Build Coastguard Worker        # op(left, right): All of the levels in right are found in left
724*da0073e9SAndroid Build Coastguard Worker        x_orig = torch.rand(B0, 7)
725*da0073e9SAndroid Build Coastguard Worker        x = x_orig.clone()
726*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(7)
727*da0073e9SAndroid Build Coastguard Worker        vmap(op, in_dims=(0, None))(x, y)
728*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, outplace_op(x_orig, y))
729*da0073e9SAndroid Build Coastguard Worker
730*da0073e9SAndroid Build Coastguard Worker        x_orig = torch.rand(B0, B1, 7)
731*da0073e9SAndroid Build Coastguard Worker        x = x_orig.clone()
732*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(B0, 7)
733*da0073e9SAndroid Build Coastguard Worker        vmap(vmap(op, in_dims=(0, None)))(x, y)
734*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7)))
735*da0073e9SAndroid Build Coastguard Worker
736*da0073e9SAndroid Build Coastguard Worker        # op(left, right): Some of the levels in right are not found in left
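        # This cannot work: writing the result in place into `left` would require
        # `left` to grow the extra batch dimension, so the fallback raises instead.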
737*da0073e9SAndroid Build Coastguard Worker        msg = r"vmap: aten::atan2_\(self, \*extra_args\) is not possible"
738*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(7)
739*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(B0, 7)
740*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
741*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(None, 0))(x, y)
742*da0073e9SAndroid Build Coastguard Worker
743*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(B1, 7)
744*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(B0, 7)
745*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
746*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y)
747*da0073e9SAndroid Build Coastguard Worker
748*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(B1, 7)
749*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(7, B0)
750*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
751*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y)
752*da0073e9SAndroid Build Coastguard Worker
753*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(B0, 7)
754*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(B0, B1, 7)
755*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
756*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(op, in_dims=(None, 0)))(x, y)
757*da0073e9SAndroid Build Coastguard Worker
758*da0073e9SAndroid Build Coastguard Worker    def test_backward_unsupported_interaction(self):
759*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, requires_grad=True)
760*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5)
761*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn_like(x)
762*da0073e9SAndroid Build Coastguard Worker        err_msg = r"backward\(\) called inside torch.vmap"
763*da0073e9SAndroid Build Coastguard Worker
764*da0073e9SAndroid Build Coastguard Worker        def backward_on_vmapped_tensor(x):
765*da0073e9SAndroid Build Coastguard Worker            x.sum().backward()
766*da0073e9SAndroid Build Coastguard Worker
767*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg):
768*da0073e9SAndroid Build Coastguard Worker            vmap(backward_on_vmapped_tensor)(x)
769*da0073e9SAndroid Build Coastguard Worker
770*da0073e9SAndroid Build Coastguard Worker        def backward_with_vmapped_grad(x, grad):
771*da0073e9SAndroid Build Coastguard Worker            x.backward(grad)
772*da0073e9SAndroid Build Coastguard Worker
773*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg):
774*da0073e9SAndroid Build Coastguard Worker            vmap(backward_with_vmapped_grad)(x, grad)
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Worker        def completely_unrelated_backward(y):
777*da0073e9SAndroid Build Coastguard Worker            x.sum().backward()
778*da0073e9SAndroid Build Coastguard Worker
779*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg):
780*da0073e9SAndroid Build Coastguard Worker            vmap(completely_unrelated_backward)(y)
781*da0073e9SAndroid Build Coastguard Worker
782*da0073e9SAndroid Build Coastguard Worker    def test_grad_unsupported_interaction(self):
783*da0073e9SAndroid Build Coastguard Worker        input_tensor = torch.randn(3, requires_grad=True)
784*da0073e9SAndroid Build Coastguard Worker        err_msg = "autograd.grad.* called inside torch.vmap"
785*da0073e9SAndroid Build Coastguard Worker
786*da0073e9SAndroid Build Coastguard Worker        captured = torch.randn(3, requires_grad=True)
787*da0073e9SAndroid Build Coastguard Worker
788*da0073e9SAndroid Build Coastguard Worker        def output_to_grad_is_vmapped(input_tensor):
789*da0073e9SAndroid Build Coastguard Worker            output = (captured * input_tensor).sum()
790*da0073e9SAndroid Build Coastguard Worker            return torch.autograd.grad([output], [captured])[0]
791*da0073e9SAndroid Build Coastguard Worker
792*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg):
793*da0073e9SAndroid Build Coastguard Worker            vmap(output_to_grad_is_vmapped)(input_tensor)
794*da0073e9SAndroid Build Coastguard Worker
795*da0073e9SAndroid Build Coastguard Worker        output = (input_tensor**2).sum()
796*da0073e9SAndroid Build Coastguard Worker
797*da0073e9SAndroid Build Coastguard Worker        def input_to_grad_is_vmapped(input_tensor):
798*da0073e9SAndroid Build Coastguard Worker            return torch.autograd.grad([output], [input_tensor])[0]
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg):
801*da0073e9SAndroid Build Coastguard Worker            vmap(input_to_grad_is_vmapped)(input_tensor)
802*da0073e9SAndroid Build Coastguard Worker
803*da0073e9SAndroid Build Coastguard Worker    def test_batched_gradient_basic(self):
804*da0073e9SAndroid Build Coastguard Worker        N = 3
805*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(N, requires_grad=True)
806*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(N)
807*da0073e9SAndroid Build Coastguard Worker
808*da0073e9SAndroid Build Coastguard Worker        def vjp_mul(v):
809*da0073e9SAndroid Build Coastguard Worker            return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]
810*da0073e9SAndroid Build Coastguard Worker
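        # Each row of the identity is a basis vector e_i, and vjp_mul(e_i) is the
        # i-th row of the Jacobian of x * y with respect to x, i.e. diag(y)[i].
        # vmapping over the rows therefore computes the full Jacobian in one call.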
811*da0073e9SAndroid Build Coastguard Worker        batched_v = torch.eye(N)
812*da0073e9SAndroid Build Coastguard Worker        jacobian = vmap(vjp_mul)(batched_v)
813*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jacobian, torch.diagflat(y))
814*da0073e9SAndroid Build Coastguard Worker
815*da0073e9SAndroid Build Coastguard Worker    def test_functools_partial(self):
816*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3)
817*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 3)
818*da0073e9SAndroid Build Coastguard Worker        result = vmap(functools.partial(torch.mul, x))(y)
819*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, x * y)
820*da0073e9SAndroid Build Coastguard Worker
821*da0073e9SAndroid Build Coastguard Worker    def test_nn_module(self):
822*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(2, 3)
823*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.Linear(3, 3, bias=False)
824*da0073e9SAndroid Build Coastguard Worker        result = vmap(model)(tensor)
825*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, model(tensor))
826*da0073e9SAndroid Build Coastguard Worker
827*da0073e9SAndroid Build Coastguard Worker    def test_fallback_with_undefined_grad(self):
828*da0073e9SAndroid Build Coastguard Worker        B0 = 7
829*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 4, 5, requires_grad=True)
830*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(3, 3, 1, 1)
831*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(B0, 2, 3, 4, 5)
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker        def get_vjp(v):
834*da0073e9SAndroid Build Coastguard Worker            result = torch.nn.functional.conv2d(x, weight)
835*da0073e9SAndroid Build Coastguard Worker            (grad_x,) = torch.autograd.grad(result, x, v)
836*da0073e9SAndroid Build Coastguard Worker            return grad_x
837*da0073e9SAndroid Build Coastguard Worker
838*da0073e9SAndroid Build Coastguard Worker        # Runs vmap(get_vjp)(v), which should not error out.
839*da0073e9SAndroid Build Coastguard Worker        # The backward formula for convolution returns an undefined
840*da0073e9SAndroid Build Coastguard Worker        # Tensor for grad_bias because the original bias does not exist.
841*da0073e9SAndroid Build Coastguard Worker        #
842*da0073e9SAndroid Build Coastguard Worker        # In the future we'll probably add a batching rule for convolution
843*da0073e9SAndroid Build Coastguard Worker        # backward. When this happens, we should modify this test to use a
844*da0073e9SAndroid Build Coastguard Worker        # different op (and/or create and use a dummy operator) to avoid bitrot.
845*da0073e9SAndroid Build Coastguard Worker        self._assert_uses_vmap_fallback([get_vjp], [v])
846*da0073e9SAndroid Build Coastguard Worker
847*da0073e9SAndroid Build Coastguard Worker
848*da0073e9SAndroid Build Coastguard Workerdef slice_inputs(inputs, bdims, i):
849*da0073e9SAndroid Build Coastguard Worker    result = []
850*da0073e9SAndroid Build Coastguard Worker    for inp, bdim in zip(inputs, bdims):
851*da0073e9SAndroid Build Coastguard Worker        if bdim is None:
852*da0073e9SAndroid Build Coastguard Worker            result.append(inp)
853*da0073e9SAndroid Build Coastguard Worker        else:
854*da0073e9SAndroid Build Coastguard Worker            result.append(inp.select(bdim, i))
855*da0073e9SAndroid Build Coastguard Worker    return tuple(result)
856*da0073e9SAndroid Build Coastguard Worker
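# A minimal sketch of what slice_inputs does (added for illustration; the helper
# below is not part of the original suite and is never called by the tests):
# a bdim of None means the input is shared as-is across the batch, while an
# integer bdim selects the i-th slice along that dimension.
def _slice_inputs_example():
    xs = torch.arange(6.0).reshape(2, 3)  # batched along dim 0
    shared = torch.ones(3)  # bdim=None: reused unchanged for every i
    sliced = slice_inputs((xs, shared), (0, None), 1)
    assert torch.equal(sliced[0], xs[1])
    assert sliced[1] is shared
    return sliced
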
857*da0073e9SAndroid Build Coastguard Worker
858*da0073e9SAndroid Build Coastguard Workerdef reference_vmap(op, inputs, in_dims=0, out_dims=0):
859*da0073e9SAndroid Build Coastguard Worker    if isinstance(in_dims, int):
860*da0073e9SAndroid Build Coastguard Worker        in_dims = (in_dims,) * len(inputs)
861*da0073e9SAndroid Build Coastguard Worker    bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None]
862*da0073e9SAndroid Build Coastguard Worker    assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes)
863*da0073e9SAndroid Build Coastguard Worker    bdim_size = bdim_sizes[0]
864*da0073e9SAndroid Build Coastguard Worker    results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size))
865*da0073e9SAndroid Build Coastguard Worker
866*da0073e9SAndroid Build Coastguard Worker    assert len(results) > 0
867*da0073e9SAndroid Build Coastguard Worker    op_has_single_return = not isinstance(results[0], tuple)
868*da0073e9SAndroid Build Coastguard Worker    if op_has_single_return:
869*da0073e9SAndroid Build Coastguard Worker        assert all(isinstance(result, torch.Tensor) for result in results)
870*da0073e9SAndroid Build Coastguard Worker        if isinstance(out_dims, int):
871*da0073e9SAndroid Build Coastguard Worker            out_dims = (out_dims,)
872*da0073e9SAndroid Build Coastguard Worker        return torch.stack(results, dim=out_dims[0])
873*da0073e9SAndroid Build Coastguard Worker
874*da0073e9SAndroid Build Coastguard Worker    assert all(isinstance(result, tuple) for result in results)
875*da0073e9SAndroid Build Coastguard Worker    num_returns = len(results[0])
876*da0073e9SAndroid Build Coastguard Worker    assert all(len(result) == num_returns for result in results)
877*da0073e9SAndroid Build Coastguard Worker    if isinstance(out_dims, int):
878*da0073e9SAndroid Build Coastguard Worker        out_dims = (out_dims,) * num_returns
879*da0073e9SAndroid Build Coastguard Worker    return tuple(
880*da0073e9SAndroid Build Coastguard Worker        torch.stack(result_shards, out_dim)
881*da0073e9SAndroid Build Coastguard Worker        for result_shards, out_dim in zip(zip(*results), out_dims)
882*da0073e9SAndroid Build Coastguard Worker    )
883*da0073e9SAndroid Build Coastguard Worker
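# A small self-check of reference_vmap (an illustrative sketch, not part of the
# original suite and never invoked by the tests): for torch.add with both
# inputs batched along dim 0, the sequential map+stack reference should agree
# with the real vmap.
def _reference_vmap_example():
    xs = torch.randn(5, 3)
    ys = torch.randn(5, 3)
    expected = reference_vmap(torch.add, (xs, ys), in_dims=(0, 0), out_dims=0)
    actual = vmap(torch.add)(xs, ys)
    assert torch.allclose(actual, expected)
    return actual
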
884*da0073e9SAndroid Build Coastguard Worker
885*da0073e9SAndroid Build Coastguard Workerclass TensorFactory:
886*da0073e9SAndroid Build Coastguard Worker    @staticmethod
887*da0073e9SAndroid Build Coastguard Worker    def rand(size, device="cpu", dtype=torch.float):
888*da0073e9SAndroid Build Coastguard Worker        return torch.rand(size, device=device, dtype=dtype)
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker    @staticmethod
891*da0073e9SAndroid Build Coastguard Worker    def randn(size, device="cpu", dtype=torch.float):
892*da0073e9SAndroid Build Coastguard Worker        return torch.randn(size, device=device, dtype=dtype)
893*da0073e9SAndroid Build Coastguard Worker
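    # Samples in [1, 2): strictly positive inputs for ops with restricted domains
    # (log, rsqrt, reciprocal, div, pow, ...).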
894*da0073e9SAndroid Build Coastguard Worker    @staticmethod
895*da0073e9SAndroid Build Coastguard Worker    def randp1(size, device="cpu", dtype=torch.float):
896*da0073e9SAndroid Build Coastguard Worker        return torch.rand(size, device=device, dtype=dtype) + 1
897*da0073e9SAndroid Build Coastguard Worker
898*da0073e9SAndroid Build Coastguard Worker
899*da0073e9SAndroid Build Coastguard Worker# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to the
900*da0073e9SAndroid Build Coastguard Worker# (slow) sequential map+stack reference implemented by reference_vmap above.
901*da0073e9SAndroid Build Coastguard Worker#
902*da0073e9SAndroid Build Coastguard Worker# check_view: Test that each returned output is a view of the first input.
903*da0073e9SAndroid Build Coastguard Worker# check_propagates_grad: Test that requires_grad propagates to the zeroth output.
904*da0073e9SAndroid Build Coastguard Workerdef _vmap_test(
905*da0073e9SAndroid Build Coastguard Worker    self,
906*da0073e9SAndroid Build Coastguard Worker    op,
907*da0073e9SAndroid Build Coastguard Worker    inputs,
908*da0073e9SAndroid Build Coastguard Worker    in_dims=0,
909*da0073e9SAndroid Build Coastguard Worker    out_dims=0,
910*da0073e9SAndroid Build Coastguard Worker    check_view=False,
911*da0073e9SAndroid Build Coastguard Worker    check_propagates_grad=True,
912*da0073e9SAndroid Build Coastguard Worker):
913*da0073e9SAndroid Build Coastguard Worker    result = vmap(op, in_dims, out_dims)(*inputs)
914*da0073e9SAndroid Build Coastguard Worker    reference_result = reference_vmap(op, inputs, in_dims, out_dims)
915*da0073e9SAndroid Build Coastguard Worker    self.assertEqual(result, reference_result)
916*da0073e9SAndroid Build Coastguard Worker    op_has_single_return = not isinstance(result, tuple)
917*da0073e9SAndroid Build Coastguard Worker
918*da0073e9SAndroid Build Coastguard Worker    if check_view:
919*da0073e9SAndroid Build Coastguard Worker        result_as_tuple = (result,) if op_has_single_return else result
920*da0073e9SAndroid Build Coastguard Worker        for output in result_as_tuple:
921*da0073e9SAndroid Build Coastguard Worker            input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base
922*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
923*da0073e9SAndroid Build Coastguard Worker                output._base is input0_base,
924*da0073e9SAndroid Build Coastguard Worker                msg="result was not a view of the first input!",
925*da0073e9SAndroid Build Coastguard Worker            )
926*da0073e9SAndroid Build Coastguard Worker
927*da0073e9SAndroid Build Coastguard Worker    if not check_propagates_grad:
928*da0073e9SAndroid Build Coastguard Worker        return
929*da0073e9SAndroid Build Coastguard Worker    # Assuming inputs[0] is a floating-point tensor. Check that the vmap
930*da0073e9SAndroid Build Coastguard Worker    # operation propagates the requires_grad flag to the zeroth output.
931*da0073e9SAndroid Build Coastguard Worker    # Some vmap operators are implemented in a way that assumes that
932*da0073e9SAndroid Build Coastguard Worker    # they are composite with respect to autograd. If the operator ever is
933*da0073e9SAndroid Build Coastguard Worker    # changed to not be composite with respect to autograd, then the
934*da0073e9SAndroid Build Coastguard Worker    # following check should fail.
935*da0073e9SAndroid Build Coastguard Worker    inputs_clone = list(inputs)
936*da0073e9SAndroid Build Coastguard Worker    inputs_clone[0] = inputs[0].clone().requires_grad_()
937*da0073e9SAndroid Build Coastguard Worker    result = vmap(op, in_dims, out_dims)(*inputs_clone)
938*da0073e9SAndroid Build Coastguard Worker    result_as_tuple = (result,) if op_has_single_return else result
939*da0073e9SAndroid Build Coastguard Worker    self.assertTrue(result_as_tuple[0].requires_grad)
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker
942*da0073e9SAndroid Build Coastguard Workerdef should_allow_vmap_fallback_usage(fn):
943*da0073e9SAndroid Build Coastguard Worker    return getattr(fn, "_allow_vmap_fallback_usage", False)
944*da0073e9SAndroid Build Coastguard Worker
945*da0073e9SAndroid Build Coastguard Worker
946*da0073e9SAndroid Build Coastguard Workerdef allowVmapFallbackUsage(fn):
947*da0073e9SAndroid Build Coastguard Worker    fn._allow_vmap_fallback_usage = True
948*da0073e9SAndroid Build Coastguard Worker    return fn
949*da0073e9SAndroid Build Coastguard Worker
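# Illustrative sketch (not part of the original suite, never invoked): the
# decorator above simply tags the function object, and TestVmapBaseLegacy
# below consults that tag to decide whether to wrap a test method with the
# fallback-usage check.
def _allow_fallback_usage_example():
    @allowVmapFallbackUsage
    def opted_out_test(self):
        vmap(torch.var_mean)(torch.rand(3))

    def ordinary_test(self):
        pass

    assert should_allow_vmap_fallback_usage(opted_out_test)
    assert not should_allow_vmap_fallback_usage(ordinary_test)
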
950*da0073e9SAndroid Build Coastguard Worker
951*da0073e9SAndroid Build Coastguard Worker# All tests of TestVmapBaseLegacy check that the slow vmap fallback is never invoked.
952*da0073e9SAndroid Build Coastguard Worker# This is so that we can incrementally add batching rules for operators to
953*da0073e9SAndroid Build Coastguard Worker# replace the slow vmap fallback path for said operators. To skip this check,
954*da0073e9SAndroid Build Coastguard Worker# please use the allowVmapFallbackUsage decorator.
955*da0073e9SAndroid Build Coastguard Worker#
956*da0073e9SAndroid Build Coastguard Worker# NB: Don't add tests to TestVmapBaseLegacy directly, unless you want them to run
957*da0073e9SAndroid Build Coastguard Worker# on every subclass of TestVmapBaseLegacy. Add them to e.g. TestVmapOperatorsLegacy.
958*da0073e9SAndroid Build Coastguard Worker#
959*da0073e9SAndroid Build Coastguard Worker# NB: TestVmapBaseLegacy is a nested class. This prevents test runners from picking
960*da0073e9SAndroid Build Coastguard Worker# it up and running it.
961*da0073e9SAndroid Build Coastguard Workerclass Namespace:
962*da0073e9SAndroid Build Coastguard Worker    class TestVmapBaseLegacy(TestCase):
963*da0073e9SAndroid Build Coastguard Worker        def __init__(self, method_name="runTest"):
964*da0073e9SAndroid Build Coastguard Worker            super().__init__(method_name)
965*da0073e9SAndroid Build Coastguard Worker
966*da0073e9SAndroid Build Coastguard Worker            test_method = getattr(self, method_name, None)
967*da0073e9SAndroid Build Coastguard Worker            if test_method is None:
968*da0073e9SAndroid Build Coastguard Worker                return
969*da0073e9SAndroid Build Coastguard Worker
970*da0073e9SAndroid Build Coastguard Worker            if not should_allow_vmap_fallback_usage(test_method):
971*da0073e9SAndroid Build Coastguard Worker                setattr(
972*da0073e9SAndroid Build Coastguard Worker                    self,
973*da0073e9SAndroid Build Coastguard Worker                    method_name,
974*da0073e9SAndroid Build Coastguard Worker                    self._wrap_method_with_vmap_fallback_check(test_method),
975*da0073e9SAndroid Build Coastguard Worker                )
976*da0073e9SAndroid Build Coastguard Worker
977*da0073e9SAndroid Build Coastguard Worker        def _wrap_method_with_vmap_fallback_check(self, method):
978*da0073e9SAndroid Build Coastguard Worker            msg = (
979*da0073e9SAndroid Build Coastguard Worker                "Expected the test to not invoke the vmap fallback path, i.e., "
980*da0073e9SAndroid Build Coastguard Worker                "all of the operators being tested in this test should have batching "
981*da0073e9SAndroid Build Coastguard Worker                "rules implemented. If you are intentionally testing something to "
982*da0073e9SAndroid Build Coastguard Worker                "do with the fallback path, use allowVmapFallbackUsage. Otherwise, "
983*da0073e9SAndroid Build Coastguard Worker                "please make sure that batching rules are implemented for the "
984*da0073e9SAndroid Build Coastguard Worker                "operator(s) being tested."
985*da0073e9SAndroid Build Coastguard Worker            )
986*da0073e9SAndroid Build Coastguard Worker
987*da0073e9SAndroid Build Coastguard Worker            @functools.wraps(method)
988*da0073e9SAndroid Build Coastguard Worker            def wrapper(self, *args, **kwargs):
989*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True) as wa:
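                    # "always" disables the default per-location deduplication so that
                    # every fallback warning emitted during the test is captured in `wa`.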
990*da0073e9SAndroid Build Coastguard Worker                    warnings.simplefilter("always")
991*da0073e9SAndroid Build Coastguard Worker                    with EnableVmapFallbackWarnings():
992*da0073e9SAndroid Build Coastguard Worker                        method(*args, **kwargs)
993*da0073e9SAndroid Build Coastguard Worker                    for captured_warning in wa:
994*da0073e9SAndroid Build Coastguard Worker                        self.assertNotRegex(
995*da0073e9SAndroid Build Coastguard Worker                            str(captured_warning.message), FALLBACK_REGEX, msg
996*da0073e9SAndroid Build Coastguard Worker                        )
997*da0073e9SAndroid Build Coastguard Worker
998*da0073e9SAndroid Build Coastguard Worker            return types.MethodType(wrapper, self)
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker        @allowVmapFallbackUsage
1001*da0073e9SAndroid Build Coastguard Worker        def test_vmap_fallback_check_ok(self):
1002*da0073e9SAndroid Build Coastguard Worker            # One day we'll implement a batching rule for torch.var_mean.
1003*da0073e9SAndroid Build Coastguard Worker            # When that happens, please change the example to use an
1004*da0073e9SAndroid Build Coastguard Worker            # operator that doesn't have a batching rule implemented.
1005*da0073e9SAndroid Build Coastguard Worker            op_using_fallback = torch.var_mean
1006*da0073e9SAndroid Build Coastguard Worker            vmap(op_using_fallback)(torch.rand(3))
1007*da0073e9SAndroid Build Coastguard Worker
1008*da0073e9SAndroid Build Coastguard Worker        def test_vmap_fallback_check(self):
1009*da0073e9SAndroid Build Coastguard Worker            @self._wrap_method_with_vmap_fallback_check
1010*da0073e9SAndroid Build Coastguard Worker            def no_fallback(self):
1011*da0073e9SAndroid Build Coastguard Worker                pass
1012*da0073e9SAndroid Build Coastguard Worker
1013*da0073e9SAndroid Build Coastguard Worker            # One day we'll implement a batching rule for torch.var_mean.
1014*da0073e9SAndroid Build Coastguard Worker            # When that happens, please change the example to use an
1015*da0073e9SAndroid Build Coastguard Worker            # operator that doesn't have a batching rule implemented.
1016*da0073e9SAndroid Build Coastguard Worker            op_using_fallback = torch.var_mean
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Worker            @self._wrap_method_with_vmap_fallback_check
1019*da0073e9SAndroid Build Coastguard Worker            def uses_fallback(self):
1020*da0073e9SAndroid Build Coastguard Worker                vmap(op_using_fallback)(torch.rand(3))
1021*da0073e9SAndroid Build Coastguard Worker
1022*da0073e9SAndroid Build Coastguard Worker            no_fallback(self)
1023*da0073e9SAndroid Build Coastguard Worker
1024*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(AssertionError):
1025*da0073e9SAndroid Build Coastguard Worker                uses_fallback(self)
1026*da0073e9SAndroid Build Coastguard Worker
1027*da0073e9SAndroid Build Coastguard Worker
1028*da0073e9SAndroid Build Coastguard Workerclass TestVmapOperatorsLegacy(Namespace.TestVmapBaseLegacy):
1029*da0073e9SAndroid Build Coastguard Worker    def _vmap_test(self, *args, **kwargs):
1030*da0073e9SAndroid Build Coastguard Worker        return _vmap_test(self, *args, **kwargs)
1031*da0073e9SAndroid Build Coastguard Worker
1032*da0073e9SAndroid Build Coastguard Worker    def _vmap_view_test(self, *args, **kwargs):
1033*da0073e9SAndroid Build Coastguard Worker        self._vmap_test(*args, **kwargs, check_view=True)
1034*da0073e9SAndroid Build Coastguard Worker
1035*da0073e9SAndroid Build Coastguard Worker    def _test_unary(self, op, getter, device, *args, **kwargs):
1036*da0073e9SAndroid Build Coastguard Worker        test = functools.partial(self._vmap_test, *args, **kwargs)
1037*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
1038*da0073e9SAndroid Build Coastguard Worker
1039*da0073e9SAndroid Build Coastguard Worker        # Single vmap, various in_dims / out_dims
1040*da0073e9SAndroid Build Coastguard Worker        test(op, [getter([B0, 3], device)])
1041*da0073e9SAndroid Build Coastguard Worker        test(op, [getter([2, 5, B0, 3], device)], in_dims=2)
1042*da0073e9SAndroid Build Coastguard Worker        test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2)
1043*da0073e9SAndroid Build Coastguard Worker
1044*da0073e9SAndroid Build Coastguard Worker        # Doubly nested vmap
1045*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), [getter([B0, B1], device)])
1046*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2)
1047*da0073e9SAndroid Build Coastguard Worker        test(
1048*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=2),
1049*da0073e9SAndroid Build Coastguard Worker            [getter([2, 5, B0, B1, 3], device)],
1050*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
1051*da0073e9SAndroid Build Coastguard Worker            out_dims=2,
1052*da0073e9SAndroid Build Coastguard Worker        )
1053*da0073e9SAndroid Build Coastguard Worker
1054*da0073e9SAndroid Build Coastguard Worker    def test_unary_pointwise_ops(self):
1055*da0073e9SAndroid Build Coastguard Worker        cases = [
1056*da0073e9SAndroid Build Coastguard Worker            (torch.abs, TensorFactory.randn),
1057*da0073e9SAndroid Build Coastguard Worker            (torch.acos, TensorFactory.rand),
1058*da0073e9SAndroid Build Coastguard Worker            (torch.asin, TensorFactory.rand),
1059*da0073e9SAndroid Build Coastguard Worker            (torch.atan, TensorFactory.rand),
1060*da0073e9SAndroid Build Coastguard Worker            (torch.ceil, TensorFactory.randn),
1061*da0073e9SAndroid Build Coastguard Worker            (torch.cos, TensorFactory.rand),
1062*da0073e9SAndroid Build Coastguard Worker            (torch.cosh, TensorFactory.rand),
1063*da0073e9SAndroid Build Coastguard Worker            (torch.digamma, TensorFactory.rand),
1064*da0073e9SAndroid Build Coastguard Worker            (torch.exp, TensorFactory.randn),
1065*da0073e9SAndroid Build Coastguard Worker            (torch.expm1, TensorFactory.randn),
1066*da0073e9SAndroid Build Coastguard Worker            (torch.floor, TensorFactory.randn),
1067*da0073e9SAndroid Build Coastguard Worker            (torch.frac, TensorFactory.randn),
1068*da0073e9SAndroid Build Coastguard Worker            (torch.lgamma, TensorFactory.rand),
1069*da0073e9SAndroid Build Coastguard Worker            (torch.log, TensorFactory.randp1),
1070*da0073e9SAndroid Build Coastguard Worker            (torch.log10, TensorFactory.randp1),
1071*da0073e9SAndroid Build Coastguard Worker            (torch.log1p, TensorFactory.randp1),
1072*da0073e9SAndroid Build Coastguard Worker            (torch.log2, TensorFactory.randp1),
1073*da0073e9SAndroid Build Coastguard Worker            (torch.neg, TensorFactory.randn),
1074*da0073e9SAndroid Build Coastguard Worker            (torch.reciprocal, TensorFactory.randp1),
1075*da0073e9SAndroid Build Coastguard Worker            (torch.relu, TensorFactory.randn),
1076*da0073e9SAndroid Build Coastguard Worker            (torch.round, TensorFactory.randn),
1077*da0073e9SAndroid Build Coastguard Worker            (torch.rsqrt, TensorFactory.randp1),
1078*da0073e9SAndroid Build Coastguard Worker            (torch.sigmoid, TensorFactory.randn),
1079*da0073e9SAndroid Build Coastguard Worker            (torch.sign, TensorFactory.randn),
1080*da0073e9SAndroid Build Coastguard Worker            (torch.sin, TensorFactory.rand),
1081*da0073e9SAndroid Build Coastguard Worker            (torch.sinh, TensorFactory.rand),
1082*da0073e9SAndroid Build Coastguard Worker            (torch.sqrt, TensorFactory.rand),
1083*da0073e9SAndroid Build Coastguard Worker            (torch.tan, TensorFactory.rand),
1084*da0073e9SAndroid Build Coastguard Worker            (torch.tanh, TensorFactory.rand),
1085*da0073e9SAndroid Build Coastguard Worker            (torch.trunc, TensorFactory.randn),
1086*da0073e9SAndroid Build Coastguard Worker        ]
1087*da0073e9SAndroid Build Coastguard Worker        for op, getter in cases:
1088*da0073e9SAndroid Build Coastguard Worker            self._test_unary(op, getter, "cpu")
1089*da0073e9SAndroid Build Coastguard Worker
1090*da0073e9SAndroid Build Coastguard Worker    def test_clone(self):
1091*da0073e9SAndroid Build Coastguard Worker        # Some basic tests
1092*da0073e9SAndroid Build Coastguard Worker        self._test_unary(lambda x: x.clone(), TensorFactory.randn, "cpu")
1093*da0073e9SAndroid Build Coastguard Worker        self._test_unary(
1094*da0073e9SAndroid Build Coastguard Worker            lambda x: x.clone(memory_format=torch.preserve_format),
1095*da0073e9SAndroid Build Coastguard Worker            TensorFactory.randn,
1096*da0073e9SAndroid Build Coastguard Worker            "cpu",
1097*da0073e9SAndroid Build Coastguard Worker        )
1098*da0073e9SAndroid Build Coastguard Worker        self._test_unary(
1099*da0073e9SAndroid Build Coastguard Worker            lambda x: x.clone(memory_format=torch.contiguous_format),
1100*da0073e9SAndroid Build Coastguard Worker            TensorFactory.randn,
1101*da0073e9SAndroid Build Coastguard Worker            "cpu",
1102*da0073e9SAndroid Build Coastguard Worker        )
1103*da0073e9SAndroid Build Coastguard Worker
1104*da0073e9SAndroid Build Coastguard Worker        # Test that the per-example outputs are contiguous when using torch.contiguous_format
1105*da0073e9SAndroid Build Coastguard Worker        def clone_contiguous(x):
1106*da0073e9SAndroid Build Coastguard Worker            return x.clone(memory_format=torch.contiguous_format)
1107*da0073e9SAndroid Build Coastguard Worker
1108*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 3, 5
1109*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, B0, 7)
1110*da0073e9SAndroid Build Coastguard Worker        y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x)
1111*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.movedim(1, 0).is_contiguous())
1112*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y[:, 0, :].is_contiguous())
1113*da0073e9SAndroid Build Coastguard Worker
1114*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, B0, 7, B1)
1115*da0073e9SAndroid Build Coastguard Worker        y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x)
1116*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.is_contiguous())
1117*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y[0][0].is_contiguous())
1118*da0073e9SAndroid Build Coastguard Worker
1119*da0073e9SAndroid Build Coastguard Worker        msg = r"only supported with memory_format torch.preserve_format or torch.contiguous_format"
1120*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1121*da0073e9SAndroid Build Coastguard Worker            vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0))
1122*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1123*da0073e9SAndroid Build Coastguard Worker            vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(
1124*da0073e9SAndroid Build Coastguard Worker                torch.randn(B0)
1125*da0073e9SAndroid Build Coastguard Worker            )
1126*da0073e9SAndroid Build Coastguard Worker
1127*da0073e9SAndroid Build Coastguard Worker    def test_binary_pointwise_ops(self):
1128*da0073e9SAndroid Build Coastguard Worker        def get_number(getter):
1129*da0073e9SAndroid Build Coastguard Worker            return getter([]).item()
1130*da0073e9SAndroid Build Coastguard Worker
1131*da0073e9SAndroid Build Coastguard Worker        def make_case(op, input_getter=TensorFactory.randn):
1132*da0073e9SAndroid Build Coastguard Worker            return (op, input_getter)
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker        cases = [
1135*da0073e9SAndroid Build Coastguard Worker            # Basic arithmetic
1136*da0073e9SAndroid Build Coastguard Worker            make_case(torch.add),
1137*da0073e9SAndroid Build Coastguard Worker            make_case(lambda x, y: x + y),
1138*da0073e9SAndroid Build Coastguard Worker            make_case(torch.sub),
1139*da0073e9SAndroid Build Coastguard Worker            make_case(lambda x, y: x - y),
1140*da0073e9SAndroid Build Coastguard Worker            make_case(torch.mul),
1141*da0073e9SAndroid Build Coastguard Worker            make_case(lambda x, y: x * y),
1142*da0073e9SAndroid Build Coastguard Worker            make_case(torch.div, input_getter=TensorFactory.randp1),
1143*da0073e9SAndroid Build Coastguard Worker            make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
1144*da0073e9SAndroid Build Coastguard Worker            make_case(torch.pow, input_getter=TensorFactory.randp1),
1145*da0073e9SAndroid Build Coastguard Worker            make_case(lambda x, y: x**y, input_getter=TensorFactory.randp1),
1146*da0073e9SAndroid Build Coastguard Worker        ]
1147*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
1148*da0073e9SAndroid Build Coastguard Worker
1149*da0073e9SAndroid Build Coastguard Worker        for op, getter in cases:
1150*da0073e9SAndroid Build Coastguard Worker            device = "cpu"
1151*da0073e9SAndroid Build Coastguard Worker            B0, B1 = 7, 11
1152*da0073e9SAndroid Build Coastguard Worker
1153*da0073e9SAndroid Build Coastguard Worker            # Single vmap: op(Tensor, Tensor)
1154*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0, 3], device), getter([B0, 3], device)))
1155*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0], device), getter([B0, 2, 3], device)))
1156*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
1157*da0073e9SAndroid Build Coastguard Worker            test(
1158*da0073e9SAndroid Build Coastguard Worker                op,
1159*da0073e9SAndroid Build Coastguard Worker                (getter([B0], device), getter([2, B0, 3], device)),
1160*da0073e9SAndroid Build Coastguard Worker                in_dims=(0, 1),
1161*da0073e9SAndroid Build Coastguard Worker                out_dims=1,
1162*da0073e9SAndroid Build Coastguard Worker            )
1163*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
1164*da0073e9SAndroid Build Coastguard Worker            test(
1165*da0073e9SAndroid Build Coastguard Worker                op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None)
1166*da0073e9SAndroid Build Coastguard Worker            )
1167*da0073e9SAndroid Build Coastguard Worker
1168*da0073e9SAndroid Build Coastguard Worker            # Nested vmap: op(Tensor, Tensor)
1169*da0073e9SAndroid Build Coastguard Worker            test(
1170*da0073e9SAndroid Build Coastguard Worker                vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device))
1171*da0073e9SAndroid Build Coastguard Worker            )
1172*da0073e9SAndroid Build Coastguard Worker            test(
1173*da0073e9SAndroid Build Coastguard Worker                vmap(op, in_dims=(None, 0)),
1174*da0073e9SAndroid Build Coastguard Worker                (getter([B0, 2, 3], device), getter([B1, 3], device)),
1175*da0073e9SAndroid Build Coastguard Worker                in_dims=(0, None),
1176*da0073e9SAndroid Build Coastguard Worker            )
1177*da0073e9SAndroid Build Coastguard Worker
1178*da0073e9SAndroid Build Coastguard Worker            # Python number overload: op(Tensor, Number) (and vice-versa)
1179*da0073e9SAndroid Build Coastguard Worker            number = get_number(getter)
1180*da0073e9SAndroid Build Coastguard Worker            self._test_unary(lambda t: op(t, number), getter, device)
1181*da0073e9SAndroid Build Coastguard Worker            number = get_number(getter)
1182*da0073e9SAndroid Build Coastguard Worker            self._test_unary(lambda t: op(number, t), getter, device)
1183*da0073e9SAndroid Build Coastguard Worker
1184*da0073e9SAndroid Build Coastguard Worker            # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
1185*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
1186*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
1187*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0], device), getter([B0], device)))
1188*da0073e9SAndroid Build Coastguard Worker
1189*da0073e9SAndroid Build Coastguard Worker            # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
1190*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
1191*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))
1192*da0073e9SAndroid Build Coastguard Worker
1193*da0073e9SAndroid Build Coastguard Worker            if not torch.cuda.is_available():
1194*da0073e9SAndroid Build Coastguard Worker                continue
1195*da0073e9SAndroid Build Coastguard Worker
1196*da0073e9SAndroid Build Coastguard Worker            # TODO(rzou): fix the following
1197*da0073e9SAndroid Build Coastguard Worker            # # Test cross-device scalars
1198*da0073e9SAndroid Build Coastguard Worker            # number = get_number(getter)
1199*da0073e9SAndroid Build Coastguard Worker            # self._test_unary(lambda t: op(t, number), getter, device='cuda')
1200*da0073e9SAndroid Build Coastguard Worker            # self._test_unary(lambda t: op(number, t), getter, device='cuda')
1201*da0073e9SAndroid Build Coastguard Worker            # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')
1202*da0073e9SAndroid Build Coastguard Worker
1203*da0073e9SAndroid Build Coastguard Worker    def test_as_strided(self):
1204*da0073e9SAndroid Build Coastguard Worker        def _test(sizes, strides, offset, tensor, lambd):
1205*da0073e9SAndroid Build Coastguard Worker            result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor)
1206*da0073e9SAndroid Build Coastguard Worker            expected = vmap(lambd)(tensor)
1207*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(result._base is expected._base)
1208*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result, expected)
1209*da0073e9SAndroid Build Coastguard Worker
1210*da0073e9SAndroid Build Coastguard Worker        # single vmap test
1211*da0073e9SAndroid Build Coastguard Worker        B0 = 5
1212*da0073e9SAndroid Build Coastguard Worker        tensors = [
1213*da0073e9SAndroid Build Coastguard Worker            # contiguous
1214*da0073e9SAndroid Build Coastguard Worker            torch.randn(B0, 2, 3),
1215*da0073e9SAndroid Build Coastguard Worker            # non-contiguous
1216*da0073e9SAndroid Build Coastguard Worker            torch.randn(B0, 3, 2).transpose(1, 2),
1217*da0073e9SAndroid Build Coastguard Worker            # non-zero storage offset
1218*da0073e9SAndroid Build Coastguard Worker            torch.randn(2, B0, 2, 3)[1],
1219*da0073e9SAndroid Build Coastguard Worker            # non-contiguous strides, zero storage offset
1220*da0073e9SAndroid Build Coastguard Worker            torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0],
1221*da0073e9SAndroid Build Coastguard Worker            # non-contiguous strides, non-zero storage offset
1222*da0073e9SAndroid Build Coastguard Worker            torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1],
1223*da0073e9SAndroid Build Coastguard Worker        ]
1224*da0073e9SAndroid Build Coastguard Worker
1225*da0073e9SAndroid Build Coastguard Worker        for x in tensors:
1226*da0073e9SAndroid Build Coastguard Worker            S0, S1 = x.stride()[1:]
1227*da0073e9SAndroid Build Coastguard Worker            offset = x.storage_offset()
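            # S0, S1 are the per-example (non-batch) strides; together with the
            # storage offset they let each view op below be expressed as an
            # equivalent as_strided call.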
1228*da0073e9SAndroid Build Coastguard Worker
1229*da0073e9SAndroid Build Coastguard Worker            # Broadcast
1230*da0073e9SAndroid Build Coastguard Worker            _test(
1231*da0073e9SAndroid Build Coastguard Worker                [5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3)
1232*da0073e9SAndroid Build Coastguard Worker            )
1233*da0073e9SAndroid Build Coastguard Worker            # transpose
1234*da0073e9SAndroid Build Coastguard Worker            _test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1))
1235*da0073e9SAndroid Build Coastguard Worker            # select
1236*da0073e9SAndroid Build Coastguard Worker            _test([2], [S0], offset + S1, x, lambda x: x[:, 1])
1237*da0073e9SAndroid Build Coastguard Worker
1238*da0073e9SAndroid Build Coastguard Worker        # Nested vmap test
1239*da0073e9SAndroid Build Coastguard Worker        B1 = 7
1240*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(B1, B0, 2, 3)
1241*da0073e9SAndroid Build Coastguard Worker        S0, S1 = x.stride()[2:]
1242*da0073e9SAndroid Build Coastguard Worker        result = vmap(
1243*da0073e9SAndroid Build Coastguard Worker            vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1
1244*da0073e9SAndroid Build Coastguard Worker        )(x)
1245*da0073e9SAndroid Build Coastguard Worker        expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x)
1246*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result._base is expected._base)
1247*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker        # Check that malformed sizes/strides don't crash
1250*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1251*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "size and stride must have the same length"
1252*da0073e9SAndroid Build Coastguard Worker        ):
1253*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B0, 2, 3).transpose(0, 1)
1254*da0073e9SAndroid Build Coastguard Worker            vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x)
1255*da0073e9SAndroid Build Coastguard Worker
1256*da0073e9SAndroid Build Coastguard Worker        # Sanity check #1: we require the batch dims to be at the front of the
1257*da0073e9SAndroid Build Coastguard Worker        # tensor (in memory layout).
1258*da0073e9SAndroid Build Coastguard Worker        msg = "batch dims being vmapped over are at the front of the tensor"
1259*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1260*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, B0, 3).transpose(0, 1)
1261*da0073e9SAndroid Build Coastguard Worker            vmap(lambda x: x.as_strided([2, 3], [B0 * 3, 1]))(x)
1262*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1263*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B0, 2, 3, B1).movedim(3, 1)
1264*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda x: x.as_strided([2, 3], [B1 * 3, B1])))(x)
1265*da0073e9SAndroid Build Coastguard Worker
1266*da0073e9SAndroid Build Coastguard Worker        # All the Sanity check #2{a,b,c} cases check that
1267*da0073e9SAndroid Build Coastguard Worker        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
1268*da0073e9SAndroid Build Coastguard Worker        # doesn't index memory that is out of bounds of xs[i]. This condition
1269*da0073e9SAndroid Build Coastguard Worker        # is important to the correctness of the as_strided batching rule
1270*da0073e9SAndroid Build Coastguard Worker        # (see NOTE: [When will the as_strided_batching_rule fail?])
1271*da0073e9SAndroid Build Coastguard Worker
1272*da0073e9SAndroid Build Coastguard Worker        # Sanity check #2a: The maximum indexable location of
1273*da0073e9SAndroid Build Coastguard Worker        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
1274*da0073e9SAndroid Build Coastguard Worker        # is less than or equal to the maximum indexable location of xs[i].
1275*da0073e9SAndroid Build Coastguard Worker        msg = "This is not supported inside of vmap"
1276*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1277*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B0, 3)
1278*da0073e9SAndroid Build Coastguard Worker            vmap(lambda x: x.as_strided([3], [1], 1))(x)
1279*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1280*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B0, 3, 5)
1281*da0073e9SAndroid Build Coastguard Worker            vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x)
1282*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1283*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B0, B1, 3, 5)
1284*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x)
1285*da0073e9SAndroid Build Coastguard Worker
1286*da0073e9SAndroid Build Coastguard Worker        # Sanity check #2b: The min indexable location of
1287*da0073e9SAndroid Build Coastguard Worker        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
1288*da0073e9SAndroid Build Coastguard Worker        # is greater than or equal to the min indexable location of xs[i].
1289*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1290*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, B0, 3)[1]
1291*da0073e9SAndroid Build Coastguard Worker            vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x)
1292*da0073e9SAndroid Build Coastguard Worker
1293*da0073e9SAndroid Build Coastguard Worker        # Sanity check #2c:
1294*da0073e9SAndroid Build Coastguard Worker        # xs[i] has zero elements, but
1295*da0073e9SAndroid Build Coastguard Worker        # xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
1296*da0073e9SAndroid Build Coastguard Worker        # does not
1297*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1298*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B0, 0, 3)
1299*da0073e9SAndroid Build Coastguard Worker            vmap(lambda x: x.as_strided([3], [1]))(x)
1300*da0073e9SAndroid Build Coastguard Worker
1301*da0073e9SAndroid Build Coastguard Worker    def test_bmm(self):
1302*da0073e9SAndroid Build Coastguard Worker        op = torch.bmm
1303*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
1304*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
1305*da0073e9SAndroid Build Coastguard Worker
1306*da0073e9SAndroid Build Coastguard Worker        # shape mismatch
1307*da0073e9SAndroid Build Coastguard Worker        msg = "Shape mismatch"
1308*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1309*da0073e9SAndroid Build Coastguard Worker            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
1310*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1311*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2))
1312*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1313*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker        # left arg is vmapped
1316*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None))
1317*da0073e9SAndroid Build Coastguard Worker        test(
1318*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None)),
1319*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)),
1320*da0073e9SAndroid Build Coastguard Worker            in_dims=(1, None),
1321*da0073e9SAndroid Build Coastguard Worker        )
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker        # right arg is vmapped
1324*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
1325*da0073e9SAndroid Build Coastguard Worker        test(
1326*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(None, 0)),
1327*da0073e9SAndroid Build Coastguard Worker            (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)),
1328*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 1),
1329*da0073e9SAndroid Build Coastguard Worker        )
1330*da0073e9SAndroid Build Coastguard Worker
1331*da0073e9SAndroid Build Coastguard Worker        # both args are vmapped
1332*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3)))
1333*da0073e9SAndroid Build Coastguard Worker        test(
1334*da0073e9SAndroid Build Coastguard Worker            vmap(op),
1335*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)),
1336*da0073e9SAndroid Build Coastguard Worker            in_dims=(1, 0),
1337*da0073e9SAndroid Build Coastguard Worker        )
1338*da0073e9SAndroid Build Coastguard Worker        test(
1339*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None)),
1340*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)),
1341*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 0),
1342*da0073e9SAndroid Build Coastguard Worker        )
1343*da0073e9SAndroid Build Coastguard Worker
1344*da0073e9SAndroid Build Coastguard Worker    def test_cat(self):
1345*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
1346*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 5, 7
1347*da0073e9SAndroid Build Coastguard Worker
1348*da0073e9SAndroid Build Coastguard Worker        # Quick hack because vmap can't accept a list of tensors as an argument
1349*da0073e9SAndroid Build Coastguard Worker        def get_op(dim):
1350*da0073e9SAndroid Build Coastguard Worker            def op(*tensors):
1351*da0073e9SAndroid Build Coastguard Worker                return torch.cat(tensors, dim=dim)
1352*da0073e9SAndroid Build Coastguard Worker
1353*da0073e9SAndroid Build Coastguard Worker            return op
1354*da0073e9SAndroid Build Coastguard Worker
1355*da0073e9SAndroid Build Coastguard Worker        test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3)))
1356*da0073e9SAndroid Build Coastguard Worker        test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0))
1357*da0073e9SAndroid Build Coastguard Worker        test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2))
1358*da0073e9SAndroid Build Coastguard Worker        test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2))
1359*da0073e9SAndroid Build Coastguard Worker        test(
1360*da0073e9SAndroid Build Coastguard Worker            vmap(get_op(0), in_dims=(0, None)),
1361*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2), torch.rand(B0, 3)),
1362*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 0),
1363*da0073e9SAndroid Build Coastguard Worker        )
1364*da0073e9SAndroid Build Coastguard Worker        test(
1365*da0073e9SAndroid Build Coastguard Worker            vmap(get_op(0), in_dims=(0, 0)),
1366*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2), torch.rand(B0, B1, 3)),
1367*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 0),
1368*da0073e9SAndroid Build Coastguard Worker        )
1369*da0073e9SAndroid Build Coastguard Worker
1370*da0073e9SAndroid Build Coastguard Worker    def test_conj(self):
1371*da0073e9SAndroid Build Coastguard Worker        op = torch.conj
1372*da0073e9SAndroid Build Coastguard Worker
1373*da0073e9SAndroid Build Coastguard Worker        def run_test(dtype):
1374*da0073e9SAndroid Build Coastguard Worker            def get(shape):
1375*da0073e9SAndroid Build Coastguard Worker                return torch.randn(shape, dtype=dtype)
1376*da0073e9SAndroid Build Coastguard Worker
1377*da0073e9SAndroid Build Coastguard Worker            B0, B1 = 7, 11
1378*da0073e9SAndroid Build Coastguard Worker            test = self._vmap_test
1379*da0073e9SAndroid Build Coastguard Worker
1380*da0073e9SAndroid Build Coastguard Worker            # Single vmap, various in_dims / out_dims
1381*da0073e9SAndroid Build Coastguard Worker            test(op, [get([B0, 3])])
1382*da0073e9SAndroid Build Coastguard Worker            test(op, [get([2, 5, B0, 3])], in_dims=2)
1383*da0073e9SAndroid Build Coastguard Worker            test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker            # Doubly nested vmap
1386*da0073e9SAndroid Build Coastguard Worker            test(vmap(op), [get([B0, B1])])
1387*da0073e9SAndroid Build Coastguard Worker            test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2)
1388*da0073e9SAndroid Build Coastguard Worker            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2)
1389*da0073e9SAndroid Build Coastguard Worker
1390*da0073e9SAndroid Build Coastguard Worker        # correctness tests
1391*da0073e9SAndroid Build Coastguard Worker        run_test(torch.float)
1392*da0073e9SAndroid Build Coastguard Worker        run_test(torch.cfloat)
1393*da0073e9SAndroid Build Coastguard Worker
1394*da0073e9SAndroid Build Coastguard Worker        # check that torch.conj on a non-complex tensor returns the same tensor
1395*da0073e9SAndroid Build Coastguard Worker        real_tensor = torch.randn(3)
1396*da0073e9SAndroid Build Coastguard Worker        result = vmap(op)(real_tensor)
1397*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.data_ptr(), real_tensor.data_ptr())
1398*da0073e9SAndroid Build Coastguard Worker
1399*da0073e9SAndroid Build Coastguard Worker    def test_contiguous(self):
1400*da0073e9SAndroid Build Coastguard Worker        op = Tensor.contiguous
1401*da0073e9SAndroid Build Coastguard Worker
1402*da0073e9SAndroid Build Coastguard Worker        self._test_unary(op, TensorFactory.randn, "cpu")
1403*da0073e9SAndroid Build Coastguard Worker
1404*da0073e9SAndroid Build Coastguard Worker        # check that contiguous returns the original tensor if the per-example
1405*da0073e9SAndroid Build Coastguard Worker        # tensors are already contiguous (see the sketch after this test)
1406*da0073e9SAndroid Build Coastguard Worker        B0 = 3
1407*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(B0, 2, 5, 7)
1408*da0073e9SAndroid Build Coastguard Worker        x = x.movedim(0, 2)
1409*da0073e9SAndroid Build Coastguard Worker        result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x)
1410*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(result is x)
1411*da0073e9SAndroid Build Coastguard Worker
1412*da0073e9SAndroid Build Coastguard Worker        msg = "NYI: querying is_contiguous inside of vmap for memory_format"
1413*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(B0, 3)
1414*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1415*da0073e9SAndroid Build Coastguard Worker            vmap(functools.partial(op, memory_format=torch.channels_last))(tensor)
1416*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1417*da0073e9SAndroid Build Coastguard Worker            vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor)
1418*da0073e9SAndroid Build Coastguard Worker
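    # Illustrative sketch (not collected as a test): the identity check in
    # test_contiguous holds because, after movedim, every per-example slice of
    # the batched input is itself contiguous, so .contiguous() has nothing to do.
    def _sketch_per_example_contiguity(self):
        B0 = 3
        x = torch.randn(B0, 2, 5, 7).movedim(0, 2)  # batch dim is now dim 2
        for i in range(B0):
            # each per-example tensor has shape (2, 5, 7) and strides (35, 7, 1)
            self.assertTrue(x.select(2, i).is_contiguous())
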
1419*da0073e9SAndroid Build Coastguard Worker    def test_stride(self):
1420*da0073e9SAndroid Build Coastguard Worker        B0 = 3
1421*da0073e9SAndroid Build Coastguard Worker
1422*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(B0, 2, 5, 7)
1423*da0073e9SAndroid Build Coastguard Worker
1424*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1425*da0073e9SAndroid Build Coastguard Worker            assert x.stride() == (7 * 5, 7, 1)
1426*da0073e9SAndroid Build Coastguard Worker            return x
1427*da0073e9SAndroid Build Coastguard Worker
1428*da0073e9SAndroid Build Coastguard Worker        vmap(foo)(x)
1429*da0073e9SAndroid Build Coastguard Worker
1430*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, B0, 5, 7).movedim(1, 0)
1431*da0073e9SAndroid Build Coastguard Worker
1432*da0073e9SAndroid Build Coastguard Worker        def bar(x):
1433*da0073e9SAndroid Build Coastguard Worker            assert x.stride() == (7 * 5 * B0, 7, 1)
1434*da0073e9SAndroid Build Coastguard Worker            return x
1435*da0073e9SAndroid Build Coastguard Worker
1436*da0073e9SAndroid Build Coastguard Worker        vmap(bar)(x)
1437*da0073e9SAndroid Build Coastguard Worker
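    # Illustrative sketch (not collected as a test): the strides observed inside
    # vmap in test_stride are exactly the strides of the per-example slices of
    # the shared storage.
    def _sketch_per_example_strides(self):
        B0 = 3
        x = torch.randn(B0, 2, 5, 7)
        self.assertEqual(x.select(0, 0).stride(), (7 * 5, 7, 1))

        y = torch.randn(2, B0, 5, 7).movedim(1, 0)
        self.assertEqual(y.select(0, 0).stride(), (7 * 5 * B0, 7, 1))
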
1438*da0073e9SAndroid Build Coastguard Worker    def test_chunk(self):
1439*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
1440*da0073e9SAndroid Build Coastguard Worker        op = torch.chunk
1441*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
1442*da0073e9SAndroid Build Coastguard Worker
1443*da0073e9SAndroid Build Coastguard Worker        # tests for torch.chunk(self, chunks: int, dim)
1444*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
1445*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))
1446*da0073e9SAndroid Build Coastguard Worker        test(
1447*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None, None)),
1448*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 1023, B0, 5), 4, 0),
1449*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None),
1450*da0073e9SAndroid Build Coastguard Worker        )
1451*da0073e9SAndroid Build Coastguard Worker        test(
1452*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
1453*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 64, B2),),
1454*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
1455*da0073e9SAndroid Build Coastguard Worker        )
1456*da0073e9SAndroid Build Coastguard Worker
1457*da0073e9SAndroid Build Coastguard Worker    def test_clamp(self):
1458*da0073e9SAndroid Build Coastguard Worker        clamp_cases = (
1459*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.clamp(min=-0.5), TensorFactory.randn),
1460*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.clamp(max=0.5), TensorFactory.randn),
1461*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn),
1462*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.clamp_min(min=-0.5), TensorFactory.randn),
1463*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.clamp_max(max=0.5), TensorFactory.randn),
1464*da0073e9SAndroid Build Coastguard Worker        )
1465*da0073e9SAndroid Build Coastguard Worker        for op, getter in clamp_cases:
1466*da0073e9SAndroid Build Coastguard Worker            self._test_unary(op, getter, "cpu")
1467*da0073e9SAndroid Build Coastguard Worker
1468*da0073e9SAndroid Build Coastguard Worker    def test_comparison_ops(self):
1469*da0073e9SAndroid Build Coastguard Worker        test = functools.partial(self._vmap_test, check_propagates_grad=False)
1470*da0073e9SAndroid Build Coastguard Worker
1471*da0073e9SAndroid Build Coastguard Worker        getter = TensorFactory.randn
1472*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
1473*da0073e9SAndroid Build Coastguard Worker
1474*da0073e9SAndroid Build Coastguard Worker        ops = (
1475*da0073e9SAndroid Build Coastguard Worker            torch.eq,
1476*da0073e9SAndroid Build Coastguard Worker            lambda x, y: x == y,
1477*da0073e9SAndroid Build Coastguard Worker            torch.gt,
1478*da0073e9SAndroid Build Coastguard Worker            lambda x, y: x > y,
1479*da0073e9SAndroid Build Coastguard Worker            torch.ge,
1480*da0073e9SAndroid Build Coastguard Worker            lambda x, y: x >= y,
1481*da0073e9SAndroid Build Coastguard Worker            torch.le,
1482*da0073e9SAndroid Build Coastguard Worker            lambda x, y: x <= y,
1483*da0073e9SAndroid Build Coastguard Worker            torch.lt,
1484*da0073e9SAndroid Build Coastguard Worker            lambda x, y: x < y,
1485*da0073e9SAndroid Build Coastguard Worker            torch.ne,
1486*da0073e9SAndroid Build Coastguard Worker            lambda x, y: x != y,
1487*da0073e9SAndroid Build Coastguard Worker        )
1488*da0073e9SAndroid Build Coastguard Worker
1489*da0073e9SAndroid Build Coastguard Worker        for op in ops:
1490*da0073e9SAndroid Build Coastguard Worker            # Single vmap: op(Tensor, Tensor)
1491*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0, 3]), getter([B0, 3])))
1492*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0]), getter([B0, 2, 3])))
1493*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1))
1494*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1)
1495*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None))
1496*da0073e9SAndroid Build Coastguard Worker            test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None))
1497*da0073e9SAndroid Build Coastguard Worker
1498*da0073e9SAndroid Build Coastguard Worker            # Nested vmap: op(Tensor, Tensor)
1499*da0073e9SAndroid Build Coastguard Worker            test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3])))
1500*da0073e9SAndroid Build Coastguard Worker            test(
1501*da0073e9SAndroid Build Coastguard Worker                vmap(op, in_dims=(None, 0)),
1502*da0073e9SAndroid Build Coastguard Worker                (getter([B0, 2, 3]), getter([B1, 3])),
1503*da0073e9SAndroid Build Coastguard Worker                in_dims=(0, None),
1504*da0073e9SAndroid Build Coastguard Worker            )
1505*da0073e9SAndroid Build Coastguard Worker
1506*da0073e9SAndroid Build Coastguard Worker            # test a Python number as the second input
1507*da0073e9SAndroid Build Coastguard Worker            number = getter([]).item()
1508*da0073e9SAndroid Build Coastguard Worker            self._test_unary(
1509*da0073e9SAndroid Build Coastguard Worker                lambda t: op(t, number), getter, "cpu", check_propagates_grad=False
1510*da0073e9SAndroid Build Coastguard Worker            )
1511*da0073e9SAndroid Build Coastguard Worker
1512*da0073e9SAndroid Build Coastguard Worker    def test_diagonal(self):
1513*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(3, 5, 7, 11, 13)
1514*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
1515*da0073e9SAndroid Build Coastguard Worker        op = torch.diagonal
1516*da0073e9SAndroid Build Coastguard Worker        test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None))
1517*da0073e9SAndroid Build Coastguard Worker        test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None))
1518*da0073e9SAndroid Build Coastguard Worker        test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None))
1519*da0073e9SAndroid Build Coastguard Worker        test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1)
1520*da0073e9SAndroid Build Coastguard Worker        test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1)
1521*da0073e9SAndroid Build Coastguard Worker        test(
1522*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3),
1523*da0073e9SAndroid Build Coastguard Worker            (tensor,),
1524*da0073e9SAndroid Build Coastguard Worker            in_dims=1,
1525*da0073e9SAndroid Build Coastguard Worker            out_dims=1,
1526*da0073e9SAndroid Build Coastguard Worker        )
1527*da0073e9SAndroid Build Coastguard Worker
1528*da0073e9SAndroid Build Coastguard Worker    def test_dot(self):
1529*da0073e9SAndroid Build Coastguard Worker        op = torch.dot
1530*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
1531*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
1532*da0073e9SAndroid Build Coastguard Worker
1533*da0073e9SAndroid Build Coastguard Worker        # shape mismatch
1534*da0073e9SAndroid Build Coastguard Worker        msg = "Shape mismatch"
1535*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1536*da0073e9SAndroid Build Coastguard Worker            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
1537*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1538*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
1539*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1540*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2))
1541*da0073e9SAndroid Build Coastguard Worker
1542*da0073e9SAndroid Build Coastguard Worker        # left arg is vmapped
1543*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None))
1544*da0073e9SAndroid Build Coastguard Worker        test(
1545*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None)),
1546*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, B0, 5), torch.rand(5)),
1547*da0073e9SAndroid Build Coastguard Worker            in_dims=(1, None),
1548*da0073e9SAndroid Build Coastguard Worker        )
1549*da0073e9SAndroid Build Coastguard Worker
1550*da0073e9SAndroid Build Coastguard Worker        # right arg is vmapped
1551*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0))
1552*da0073e9SAndroid Build Coastguard Worker        test(
1553*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(None, 0)),
1554*da0073e9SAndroid Build Coastguard Worker            (torch.rand(5), torch.rand(B1, B0, 5)),
1555*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 1),
1556*da0073e9SAndroid Build Coastguard Worker        )
1557*da0073e9SAndroid Build Coastguard Worker
1558*da0073e9SAndroid Build Coastguard Worker        # both args are vmapped
1559*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 5), torch.rand(B0, 5)))
1560*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
1561*da0073e9SAndroid Build Coastguard Worker        test(
1562*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None)),
1563*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 5), torch.rand(B0, 5)),
1564*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 0),
1565*da0073e9SAndroid Build Coastguard Worker        )
1566*da0073e9SAndroid Build Coastguard Worker
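    # Illustrative sketch (not collected as a test): over a shared batch dim,
    # vmap(torch.dot) is a batched inner product, i.e. an elementwise multiply
    # followed by a sum over the feature dimension.
    def _sketch_batched_dot(self):
        B0 = 7
        x, y = torch.rand(B0, 5), torch.rand(B0, 5)
        self.assertEqual(vmap(torch.dot)(x, y), (x * y).sum(-1))
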
1567*da0073e9SAndroid Build Coastguard Worker    def test_expand_as(self):
1568*da0073e9SAndroid Build Coastguard Worker        op = torch.Tensor.expand_as
1569*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
1570*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
1571*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
1572*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
1573*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
1574*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
1575*da0073e9SAndroid Build Coastguard Worker        test(
1576*da0073e9SAndroid Build Coastguard Worker            vmap(op),
1577*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)),
1578*da0073e9SAndroid Build Coastguard Worker            in_dims=(0, 1),
1579*da0073e9SAndroid Build Coastguard Worker        )
1580*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
1581*da0073e9SAndroid Build Coastguard Worker        test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))
1582*da0073e9SAndroid Build Coastguard Worker
1583*da0073e9SAndroid Build Coastguard Worker    def test_fill_and_zero_inplace(self):
1584*da0073e9SAndroid Build Coastguard Worker        test = functools.partial(self._vmap_test, check_propagates_grad=False)
1585*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
1586*da0073e9SAndroid Build Coastguard Worker        ops = (
1587*da0073e9SAndroid Build Coastguard Worker            lambda t: t.fill_(0.1),
1588*da0073e9SAndroid Build Coastguard Worker            lambda t: t.fill_(torch.tensor(0.2)),
1589*da0073e9SAndroid Build Coastguard Worker            lambda t: t.zero_(),
1590*da0073e9SAndroid Build Coastguard Worker        )
1591*da0073e9SAndroid Build Coastguard Worker
1592*da0073e9SAndroid Build Coastguard Worker        for op in ops:
1593*da0073e9SAndroid Build Coastguard Worker            # Single vmap, various in_dims / out_dims
1594*da0073e9SAndroid Build Coastguard Worker            test(op, [TensorFactory.randn([B0, 3])])
1595*da0073e9SAndroid Build Coastguard Worker            test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2)
1596*da0073e9SAndroid Build Coastguard Worker            test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
1597*da0073e9SAndroid Build Coastguard Worker
1598*da0073e9SAndroid Build Coastguard Worker            # Doubly nested vmap
1599*da0073e9SAndroid Build Coastguard Worker            test(vmap(op), [TensorFactory.randn([B0, B1])])
1600*da0073e9SAndroid Build Coastguard Worker            test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2)
1601*da0073e9SAndroid Build Coastguard Worker            test(
1602*da0073e9SAndroid Build Coastguard Worker                vmap(op, in_dims=2),
1603*da0073e9SAndroid Build Coastguard Worker                [TensorFactory.randn([2, 5, B0, B1, 3])],
1604*da0073e9SAndroid Build Coastguard Worker                in_dims=2,
1605*da0073e9SAndroid Build Coastguard Worker                out_dims=2,
1606*da0073e9SAndroid Build Coastguard Worker            )
1607*da0073e9SAndroid Build Coastguard Worker
1608*da0073e9SAndroid Build Coastguard Worker        # test fill_ when the fill value is a batched tensor (see the sketch after this test)
1609*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 3, 5
1610*da0073e9SAndroid Build Coastguard Worker        test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)])
1611*da0073e9SAndroid Build Coastguard Worker
1612*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1613*da0073e9SAndroid Build Coastguard Worker            RuntimeError, r"output with shape .+ doesn't match the broadcast shape"
1614*da0073e9SAndroid Build Coastguard Worker        ):
1615*da0073e9SAndroid Build Coastguard Worker            # A RuntimeError is thrown when the tensor being written to isn't being vmapped over
1616*da0073e9SAndroid Build Coastguard Worker            vmap(Tensor.fill_, (None, 0))(
1617*da0073e9SAndroid Build Coastguard Worker                TensorFactory.randn([B0, B1]), TensorFactory.randn([B0])
1618*da0073e9SAndroid Build Coastguard Worker            )
1619*da0073e9SAndroid Build Coastguard Worker
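    # Illustrative sketch (not collected as a test): when the fill value is
    # batched alongside the tensor, each per-example row is filled with its own
    # scalar, so the result equals the value broadcast across the example shape.
    def _sketch_batched_fill(self):
        B0, B1 = 3, 5
        a = torch.zeros(B0, B1)
        v = torch.randn(B0)
        out = vmap(Tensor.fill_)(a, v)
        self.assertEqual(out, v.unsqueeze(-1).expand(B0, B1))
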
1620*da0073e9SAndroid Build Coastguard Worker    def _test_complex_views(self, op, dtypes):
1621*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
1622*da0073e9SAndroid Build Coastguard Worker
1623*da0073e9SAndroid Build Coastguard Worker        def run_test(op, dtype):
1624*da0073e9SAndroid Build Coastguard Worker            def get(shape):
1625*da0073e9SAndroid Build Coastguard Worker                return torch.randn(shape, dtype=dtype)
1626*da0073e9SAndroid Build Coastguard Worker
1627*da0073e9SAndroid Build Coastguard Worker            B0, B1 = 7, 11
1628*da0073e9SAndroid Build Coastguard Worker
1629*da0073e9SAndroid Build Coastguard Worker            # Single vmap, various in_dims / out_dims
1630*da0073e9SAndroid Build Coastguard Worker            test(op, [get([B0, 3])])
1631*da0073e9SAndroid Build Coastguard Worker            test(op, [get([3, B0])], in_dims=1)
1632*da0073e9SAndroid Build Coastguard Worker            test(op, [get([2, 5, B0, 3])], in_dims=2)
1633*da0073e9SAndroid Build Coastguard Worker            test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
1634*da0073e9SAndroid Build Coastguard Worker
1635*da0073e9SAndroid Build Coastguard Worker            # Doubly nested vmap
1636*da0073e9SAndroid Build Coastguard Worker            test(vmap(op), [get([B0, B1])])
1637*da0073e9SAndroid Build Coastguard Worker            test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4)
1638*da0073e9SAndroid Build Coastguard Worker            test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])], in_dims=2, out_dims=2)
1639*da0073e9SAndroid Build Coastguard Worker
1640*da0073e9SAndroid Build Coastguard Worker        for dtype in dtypes:
1641*da0073e9SAndroid Build Coastguard Worker            run_test(op, dtype)
1642*da0073e9SAndroid Build Coastguard Worker
1643*da0073e9SAndroid Build Coastguard Worker    def test_real(self):
1644*da0073e9SAndroid Build Coastguard Worker        self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble])
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker    def test_imag(self):
1647*da0073e9SAndroid Build Coastguard Worker        self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble])
1648*da0073e9SAndroid Build Coastguard Worker
1649*da0073e9SAndroid Build Coastguard Worker    def test_view_as_real(self):
1650*da0073e9SAndroid Build Coastguard Worker        self._test_complex_views(
1651*da0073e9SAndroid Build Coastguard Worker            torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble]
1652*da0073e9SAndroid Build Coastguard Worker        )
1653*da0073e9SAndroid Build Coastguard Worker
1654*da0073e9SAndroid Build Coastguard Worker    def test_view_as_complex(self):
1655*da0073e9SAndroid Build Coastguard Worker        def run_test(dtype):
1656*da0073e9SAndroid Build Coastguard Worker            def get(shape):
1657*da0073e9SAndroid Build Coastguard Worker                return torch.randn(shape, dtype=dtype)
1658*da0073e9SAndroid Build Coastguard Worker
1659*da0073e9SAndroid Build Coastguard Worker            op = torch.view_as_complex
1660*da0073e9SAndroid Build Coastguard Worker            test = self._vmap_view_test
1661*da0073e9SAndroid Build Coastguard Worker            B0, B1 = 7, 11
1662*da0073e9SAndroid Build Coastguard Worker
1663*da0073e9SAndroid Build Coastguard Worker            # Single vmap, various in_dims / out_dims
1664*da0073e9SAndroid Build Coastguard Worker            test(op, [get([B0, 3, 2])])
1665*da0073e9SAndroid Build Coastguard Worker            test(op, [get([2, 5, B0, 3, 2])], in_dims=2)
1666*da0073e9SAndroid Build Coastguard Worker            test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2)
1667*da0073e9SAndroid Build Coastguard Worker
1668*da0073e9SAndroid Build Coastguard Worker            # Doubly nested vmap
1669*da0073e9SAndroid Build Coastguard Worker            test(vmap(op), [get([B0, B1, 2])])
1670*da0073e9SAndroid Build Coastguard Worker            test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2)
1671*da0073e9SAndroid Build Coastguard Worker            test(
1672*da0073e9SAndroid Build Coastguard Worker                vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])], in_dims=2, out_dims=2
1673*da0073e9SAndroid Build Coastguard Worker            )
1674*da0073e9SAndroid Build Coastguard Worker
1675*da0073e9SAndroid Build Coastguard Worker            # Interesting case #1: Batch dim directly before dim of size 2
1676*da0073e9SAndroid Build Coastguard Worker            test(op, [get([3, B0, 2])], in_dims=1)
1677*da0073e9SAndroid Build Coastguard Worker            test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2)
1678*da0073e9SAndroid Build Coastguard Worker
1679*da0073e9SAndroid Build Coastguard Worker            # Interesting case #2: Batch dim at end of tensor, success cases
1680*da0073e9SAndroid Build Coastguard Worker            # view_as_complex requires that the dim with size 2 have stride 1
1681*da0073e9SAndroid Build Coastguard Worker            # in order for the view to function properly (see the sketch after this test)
1682*da0073e9SAndroid Build Coastguard Worker            test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1)
1683*da0073e9SAndroid Build Coastguard Worker            test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)])
1684*da0073e9SAndroid Build Coastguard Worker            test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)])
1685*da0073e9SAndroid Build Coastguard Worker
1686*da0073e9SAndroid Build Coastguard Worker            # Interesting case #3: Batch dim at end of tensor, failure cases
1687*da0073e9SAndroid Build Coastguard Worker            msg = "Tensor must have a last dimension with stride 1"
1688*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, msg):
1689*da0073e9SAndroid Build Coastguard Worker                vmap(op, in_dims=1)(get([2, B0]))
1690*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, msg):
1691*da0073e9SAndroid Build Coastguard Worker                vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1]))
1692*da0073e9SAndroid Build Coastguard Worker
1693*da0073e9SAndroid Build Coastguard Worker            # Invalid input: no dimension of size 2
1694*da0073e9SAndroid Build Coastguard Worker            msg = "Input tensor must have one or more dimensions"
1695*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, msg):
1696*da0073e9SAndroid Build Coastguard Worker                vmap(op)(get([B0]))
1697*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, msg):
1698*da0073e9SAndroid Build Coastguard Worker                vmap(vmap(op))(get([B0, B1]))
1699*da0073e9SAndroid Build Coastguard Worker
1700*da0073e9SAndroid Build Coastguard Worker            # Invalid input: Batch dim has size 2, but the logical last dim does
1701*da0073e9SAndroid Build Coastguard Worker            # not have size 2
1702*da0073e9SAndroid Build Coastguard Worker            msg = "Tensor must have a last dimension of size 2"
1703*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, msg):
1704*da0073e9SAndroid Build Coastguard Worker                vmap(op, in_dims=1)(get([3, 2]))
1705*da0073e9SAndroid Build Coastguard Worker
1706*da0073e9SAndroid Build Coastguard Worker        for dtype in [torch.float, torch.double]:
1707*da0073e9SAndroid Build Coastguard Worker            run_test(dtype)
1708*da0073e9SAndroid Build Coastguard Worker
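    # Illustrative sketch (not collected as a test): even outside of vmap,
    # view_as_complex requires the trailing size-2 dim to have stride 1, which is
    # why the "batch dim at end of tensor" failure cases above raise.
    def _sketch_view_as_complex_stride(self):
        ok = torch.randn(3, 2)  # last dim has stride 1
        self.assertEqual(torch.view_as_complex(ok).shape, (3,))

        bad = torch.randn(2, 3).t()  # shape (3, 2), but last dim has stride 3
        with self.assertRaisesRegex(
            RuntimeError, "Tensor must have a last dimension with stride 1"
        ):
            torch.view_as_complex(bad)
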
1709*da0073e9SAndroid Build Coastguard Worker    def test_is_complex(self):
1710*da0073e9SAndroid Build Coastguard Worker        ctensor = torch.randn(3, dtype=torch.cfloat)
1711*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(3)
1712*da0073e9SAndroid Build Coastguard Worker
1713*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1714*da0073e9SAndroid Build Coastguard Worker            if x.is_complex():
1715*da0073e9SAndroid Build Coastguard Worker                return torch.tensor(1)
1716*da0073e9SAndroid Build Coastguard Worker            else:
1717*da0073e9SAndroid Build Coastguard Worker                return torch.tensor(0)
1718*da0073e9SAndroid Build Coastguard Worker
1719*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
1720*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))
1721*da0073e9SAndroid Build Coastguard Worker
1722*da0073e9SAndroid Build Coastguard Worker    def test_is_floating_point(self):
1723*da0073e9SAndroid Build Coastguard Worker        float_tensor = torch.tensor([1.0, 2.0, 3.0])
1724*da0073e9SAndroid Build Coastguard Worker        long_tensor = torch.tensor([1, 2, 3])
1725*da0073e9SAndroid Build Coastguard Worker
1726*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1727*da0073e9SAndroid Build Coastguard Worker            if x.is_floating_point():
1728*da0073e9SAndroid Build Coastguard Worker                return torch.tensor(1)
1729*da0073e9SAndroid Build Coastguard Worker            else:
1730*da0073e9SAndroid Build Coastguard Worker                return torch.tensor(0)
1731*da0073e9SAndroid Build Coastguard Worker
1732*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1]))
1733*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0]))
1734*da0073e9SAndroid Build Coastguard Worker
1735*da0073e9SAndroid Build Coastguard Worker    def test_is_contiguous(self):
1736*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1737*da0073e9SAndroid Build Coastguard Worker            if x.is_contiguous():
1738*da0073e9SAndroid Build Coastguard Worker                return torch.tensor(1.0)
1739*da0073e9SAndroid Build Coastguard Worker            else:
1740*da0073e9SAndroid Build Coastguard Worker                return torch.tensor(0.0)
1741*da0073e9SAndroid Build Coastguard Worker
1742*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 3, 5
1743*da0073e9SAndroid Build Coastguard Worker
1744*da0073e9SAndroid Build Coastguard Worker        # Single batch dim
1745*da0073e9SAndroid Build Coastguard Worker        contig = torch.randn(B0, 2, 7)
1746*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(foo)(contig), torch.ones(B0))
1747*da0073e9SAndroid Build Coastguard Worker
1748*da0073e9SAndroid Build Coastguard Worker        noncontig = torch.randn(2, B0, 7)
1749*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0))
1750*da0073e9SAndroid Build Coastguard Worker
1751*da0073e9SAndroid Build Coastguard Worker        noncontig = torch.randn(2, B0, 7).movedim(1, 0)
1752*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0))
1753*da0073e9SAndroid Build Coastguard Worker
1754*da0073e9SAndroid Build Coastguard Worker        noncontig = torch.randn(2, 7, B0)
1755*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0))
1756*da0073e9SAndroid Build Coastguard Worker
1757*da0073e9SAndroid Build Coastguard Worker        # Multiple batch dims
1758*da0073e9SAndroid Build Coastguard Worker        contig = torch.randn(B0, B1, 3)
1759*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
1760*da0073e9SAndroid Build Coastguard Worker
1761*da0073e9SAndroid Build Coastguard Worker        contig = torch.randn(B1, B0, 3)
1762*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1))
1763*da0073e9SAndroid Build Coastguard Worker
1764*da0073e9SAndroid Build Coastguard Worker        contig = torch.randn(B1, B0, 3).movedim(0, 1)
1765*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
1766*da0073e9SAndroid Build Coastguard Worker
1767*da0073e9SAndroid Build Coastguard Worker        noncontig = torch.randn(B0, 3, B1)
1768*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1))
1769*da0073e9SAndroid Build Coastguard Worker
1770*da0073e9SAndroid Build Coastguard Worker        # is_contiguous on empty tensor is True
1771*da0073e9SAndroid Build Coastguard Worker        def bar(x):
1772*da0073e9SAndroid Build Coastguard Worker            assert x.is_contiguous()
1773*da0073e9SAndroid Build Coastguard Worker            return x
1774*da0073e9SAndroid Build Coastguard Worker
1775*da0073e9SAndroid Build Coastguard Worker        vmap(bar)(torch.randn(B0, 0, 3))
1776*da0073e9SAndroid Build Coastguard Worker        vmap(bar, in_dims=1)(torch.randn(0, B0, 3))
1777*da0073e9SAndroid Build Coastguard Worker        vmap(bar)(torch.randn(B0, 0, 3).mT)
1778*da0073e9SAndroid Build Coastguard Worker
1779*da0073e9SAndroid Build Coastguard Worker        # is_contiguous with other memory formats
1780*da0073e9SAndroid Build Coastguard Worker        def baz(x, memory_format):
1781*da0073e9SAndroid Build Coastguard Worker            x.is_contiguous(memory_format=memory_format)
1782*da0073e9SAndroid Build Coastguard Worker            return x
1783*da0073e9SAndroid Build Coastguard Worker
1784*da0073e9SAndroid Build Coastguard Worker        msg = "NYI: querying is_contiguous inside of vmap for memory_format"
1785*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(B0, 2, 7, 3)
1786*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1787*da0073e9SAndroid Build Coastguard Worker            vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor)
1788*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1789*da0073e9SAndroid Build Coastguard Worker            vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
1790*da0073e9SAndroid Build Coastguard Worker
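    # Illustrative sketch (not collected as a test): under vmap, is_contiguous()
    # reports on the per-example tensors, not on the batched tensor, which is why
    # a contiguous input whose batch dim sits in the middle reads as
    # non-contiguous in test_is_contiguous.
    def _sketch_per_example_is_contiguous(self):
        B0 = 3
        noncontig = torch.randn(2, B0, 7)
        self.assertTrue(noncontig.is_contiguous())
        # but each per-example slice along dim 1 has strides (B0 * 7, 1)
        self.assertFalse(noncontig.select(1, 0).is_contiguous())
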
1791*da0073e9SAndroid Build Coastguard Worker    def test_movedim(self):
1792*da0073e9SAndroid Build Coastguard Worker        op = torch.movedim
1793*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
1794*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
1795*da0073e9SAndroid Build Coastguard Worker
1796*da0073e9SAndroid Build Coastguard Worker        # movedim(tensor, int, int) variant
1797*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None))
1798*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None))
1799*da0073e9SAndroid Build Coastguard Worker        test(
1800*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None, None)),
1801*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 5), 0, 1),
1802*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None),
1803*da0073e9SAndroid Build Coastguard Worker        )
1804*da0073e9SAndroid Build Coastguard Worker        test(
1805*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
1806*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 5, B2), 0, 1),
1807*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None),
1808*da0073e9SAndroid Build Coastguard Worker        )
1809*da0073e9SAndroid Build Coastguard Worker
1810*da0073e9SAndroid Build Coastguard Worker        # movedim(tensor, intlist, intlist) variant
1811*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None))
1812*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None))
1813*da0073e9SAndroid Build Coastguard Worker        test(
1814*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None, None)),
1815*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]),
1816*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None),
1817*da0073e9SAndroid Build Coastguard Worker        )
1818*da0073e9SAndroid Build Coastguard Worker        test(
1819*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
1820*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]),
1821*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None),
1822*da0073e9SAndroid Build Coastguard Worker        )
1823*da0073e9SAndroid Build Coastguard Worker
1824*da0073e9SAndroid Build Coastguard Worker    def test_mm(self):
1825*da0073e9SAndroid Build Coastguard Worker        op = torch.mm
1826*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
1827*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
1828*da0073e9SAndroid Build Coastguard Worker
1829*da0073e9SAndroid Build Coastguard Worker        # shape mismatch
1830*da0073e9SAndroid Build Coastguard Worker        msg = "Shape mismatch"
1831*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1832*da0073e9SAndroid Build Coastguard Worker            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
1833*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1834*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
1835*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1836*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))
1837*da0073e9SAndroid Build Coastguard Worker
1838*da0073e9SAndroid Build Coastguard Worker        # left arg is vmapped
1839*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None))
1840*da0073e9SAndroid Build Coastguard Worker        test(
1841*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None)),
1842*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)),
1843*da0073e9SAndroid Build Coastguard Worker            in_dims=(1, None),
1844*da0073e9SAndroid Build Coastguard Worker        )
1845*da0073e9SAndroid Build Coastguard Worker
1846*da0073e9SAndroid Build Coastguard Worker        # right arg is vmapped
1847*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
1848*da0073e9SAndroid Build Coastguard Worker        test(
1849*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(None, 0)),
1850*da0073e9SAndroid Build Coastguard Worker            (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)),
1851*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 1),
1852*da0073e9SAndroid Build Coastguard Worker        )
1853*da0073e9SAndroid Build Coastguard Worker
1854*da0073e9SAndroid Build Coastguard Worker        # both args are vmapped
1855*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2)))
1856*da0073e9SAndroid Build Coastguard Worker        test(
1857*da0073e9SAndroid Build Coastguard Worker            vmap(op),
1858*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)),
1859*da0073e9SAndroid Build Coastguard Worker            in_dims=(1, 0),
1860*da0073e9SAndroid Build Coastguard Worker        )
1861*da0073e9SAndroid Build Coastguard Worker        test(
1862*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None)),
1863*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)),
1864*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 0),
1865*da0073e9SAndroid Build Coastguard Worker        )
1866*da0073e9SAndroid Build Coastguard Worker
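    # Illustrative sketch (not collected as a test): when both operands share the
    # leading batch dim, vmap(torch.mm) computes the same result as torch.bmm.
    def _sketch_mm_vs_bmm(self):
        B0 = 7
        x = torch.rand(B0, 2, 5)
        y = torch.rand(B0, 5, 2)
        self.assertEqual(vmap(torch.mm)(x, y), torch.bmm(x, y))
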
1867*da0073e9SAndroid Build Coastguard Worker    def test_mv(self):
1868*da0073e9SAndroid Build Coastguard Worker        op = torch.mv
1869*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
1870*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
1871*da0073e9SAndroid Build Coastguard Worker
1872*da0073e9SAndroid Build Coastguard Worker        # shape mismatch
1873*da0073e9SAndroid Build Coastguard Worker        msg = "Shape mismatch"
1874*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1875*da0073e9SAndroid Build Coastguard Worker            vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
1876*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1877*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2))
1878*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
1879*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2))
1880*da0073e9SAndroid Build Coastguard Worker
1881*da0073e9SAndroid Build Coastguard Worker        # left arg is vmapped
1882*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None))
1883*da0073e9SAndroid Build Coastguard Worker        test(
1884*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None)),
1885*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, B0, 2, 5), torch.rand(5)),
1886*da0073e9SAndroid Build Coastguard Worker            in_dims=(1, None),
1887*da0073e9SAndroid Build Coastguard Worker        )
1888*da0073e9SAndroid Build Coastguard Worker
1889*da0073e9SAndroid Build Coastguard Worker        # right arg is vmapped
1890*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0))
1891*da0073e9SAndroid Build Coastguard Worker        test(
1892*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(None, 0)),
1893*da0073e9SAndroid Build Coastguard Worker            (torch.rand(2, 5), torch.rand(B1, B0, 5)),
1894*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 1),
1895*da0073e9SAndroid Build Coastguard Worker        )
1896*da0073e9SAndroid Build Coastguard Worker
1897*da0073e9SAndroid Build Coastguard Worker        # both args are vmapped
1898*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5)))
1899*da0073e9SAndroid Build Coastguard Worker        test(
1900*da0073e9SAndroid Build Coastguard Worker            vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0)
1901*da0073e9SAndroid Build Coastguard Worker        )
1902*da0073e9SAndroid Build Coastguard Worker        test(
1903*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None)),
1904*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, 5), torch.rand(B0, 5)),
1905*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 0),
1906*da0073e9SAndroid Build Coastguard Worker        )
1907*da0073e9SAndroid Build Coastguard Worker
1908*da0073e9SAndroid Build Coastguard Worker    def test_narrow(self):
1909*da0073e9SAndroid Build Coastguard Worker        op = torch.narrow
1910*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
1911*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
1912*da0073e9SAndroid Build Coastguard Worker
1913*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None))
1914*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None))
1915*da0073e9SAndroid Build Coastguard Worker        test(
1916*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None, None, None)),
1917*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 5), 1, 0, 0),
1918*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None, None),
1919*da0073e9SAndroid Build Coastguard Worker        )
1920*da0073e9SAndroid Build Coastguard Worker        test(
1921*da0073e9SAndroid Build Coastguard Worker            vmap(
1922*da0073e9SAndroid Build Coastguard Worker                vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)
1923*da0073e9SAndroid Build Coastguard Worker            ),
1924*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 5, B2), -1, 2, 3),
1925*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None, None),
1926*da0073e9SAndroid Build Coastguard Worker        )
1927*da0073e9SAndroid Build Coastguard Worker
1928*da0073e9SAndroid Build Coastguard Worker    def test_new_empty(self):
1929*da0073e9SAndroid Build Coastguard Worker        # The values of an empty tensor are uninitialized, so we just check the
1930*da0073e9SAndroid Build Coastguard Worker        # shape of the output tensor and that the vmap fallback isn't used.
1931*da0073e9SAndroid Build Coastguard Worker        op = Tensor.new_empty
1932*da0073e9SAndroid Build Coastguard Worker
1933*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
1934*da0073e9SAndroid Build Coastguard Worker
1935*da0073e9SAndroid Build Coastguard Worker        result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0))
1936*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.shape, [B0, 2, 3])
1937*da0073e9SAndroid Build Coastguard Worker
1938*da0073e9SAndroid Build Coastguard Worker        result = vmap(lambda x: op(x, []))(torch.randn(B0))
1939*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.shape, [B0])
1940*da0073e9SAndroid Build Coastguard Worker
1941*da0073e9SAndroid Build Coastguard Worker        result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1))
1942*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result.shape, [B0, B1, 2, 3])
1943*da0073e9SAndroid Build Coastguard Worker
1944*da0073e9SAndroid Build Coastguard Worker    def test_new_empty_strided(self):
1945*da0073e9SAndroid Build Coastguard Worker        # The values of an empty tensor are uninitialized, so we just check the
1946*da0073e9SAndroid Build Coastguard Worker        # shape and strides of the output and that the vmap fallback isn't used
1947*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
1948*da0073e9SAndroid Build Coastguard Worker
1949*da0073e9SAndroid Build Coastguard Worker        def _test_single_vmap(size, stride, B0):
1950*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B0)
1951*da0073e9SAndroid Build Coastguard Worker            result = vmap(lambda x: x.new_empty_strided(size, stride))(x)
1952*da0073e9SAndroid Build Coastguard Worker            S = torch.empty_strided(size, stride).storage().size()
1953*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, [B0] + size)
1954*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.stride(), [S] + stride)
1955*da0073e9SAndroid Build Coastguard Worker
1956*da0073e9SAndroid Build Coastguard Worker        def _test_double_vmap(size, stride, B0, B1):
1957*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B0, B1)
1958*da0073e9SAndroid Build Coastguard Worker            result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x)
1959*da0073e9SAndroid Build Coastguard Worker            S = torch.empty_strided(size, stride).storage().size()
1960*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, [B0, B1] + size)
1961*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.stride(), [B1 * S, S] + stride)
1962*da0073e9SAndroid Build Coastguard Worker
1963*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(B1, B0)
1964*da0073e9SAndroid Build Coastguard Worker            result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)(
1965*da0073e9SAndroid Build Coastguard Worker                x
1966*da0073e9SAndroid Build Coastguard Worker            )
1967*da0073e9SAndroid Build Coastguard Worker            S = x.new_empty_strided(size, stride).storage().size()
1968*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.shape, [B0, B1] + size)
1969*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.stride(), [B1 * S, S] + stride)
1970*da0073e9SAndroid Build Coastguard Worker
1971*da0073e9SAndroid Build Coastguard Worker        # contiguous case
1972*da0073e9SAndroid Build Coastguard Worker        _test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0)
1973*da0073e9SAndroid Build Coastguard Worker        _test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1)
1974*da0073e9SAndroid Build Coastguard Worker
1975*da0073e9SAndroid Build Coastguard Worker        # expanded
1976*da0073e9SAndroid Build Coastguard Worker        _test_single_vmap([2, 3, 5], [0, 5, 1], B0)
1977*da0073e9SAndroid Build Coastguard Worker        _test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1)
1978*da0073e9SAndroid Build Coastguard Worker
1979*da0073e9SAndroid Build Coastguard Worker        # some of these cases are pretty strange, just verifying that if
1980*da0073e9SAndroid Build Coastguard Worker        # empty_strided allows them then BatchedTensor.new_empty_strided
1981*da0073e9SAndroid Build Coastguard Worker        # can as well
1982*da0073e9SAndroid Build Coastguard Worker        for shape in [[2, 3, 4], [0, 2, 0]]:
1983*da0073e9SAndroid Build Coastguard Worker            for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]:
1984*da0073e9SAndroid Build Coastguard Worker                _test_single_vmap(shape, strides, B0)
1985*da0073e9SAndroid Build Coastguard Worker                _test_double_vmap(shape, strides, B0, B1)
1986*da0073e9SAndroid Build Coastguard Worker
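    # Illustrative sketch (not collected as a test): the expected batch stride in
    # test_new_empty_strided is the storage footprint of one per-example tensor.
    # For the expanded case, size [2, 3, 5] with stride [0, 5, 1] needs
    # 1 + 1*0 + 2*5 + 4*1 = 15 elements, so consecutive examples are laid out
    # 15 elements apart.
    def _sketch_empty_strided_footprint(self):
        size, stride = [2, 3, 5], [0, 5, 1]
        S = torch.empty_strided(size, stride).storage().size()
        self.assertEqual(S, 15)

        B0 = 7
        result = vmap(lambda x: x.new_empty_strided(size, stride))(torch.randn(B0))
        self.assertEqual(result.stride(), [S] + stride)
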
1987*da0073e9SAndroid Build Coastguard Worker    def test_new_zeros(self):
1988*da0073e9SAndroid Build Coastguard Worker        op = Tensor.new_zeros
1989*da0073e9SAndroid Build Coastguard Worker        test = functools.partial(self._vmap_test, check_propagates_grad=False)
1990*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
1991*da0073e9SAndroid Build Coastguard Worker
1992*da0073e9SAndroid Build Coastguard Worker        test(lambda x: op(x, 2, 3), (torch.rand(B0),))
1993*da0073e9SAndroid Build Coastguard Worker        test(lambda x: op(x, []), (torch.rand(B0),))
1994*da0073e9SAndroid Build Coastguard Worker        test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),))
1995*da0073e9SAndroid Build Coastguard Worker
1996*da0073e9SAndroid Build Coastguard Worker    def test_select(self):
1997*da0073e9SAndroid Build Coastguard Worker        op = torch.select
1998*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
1999*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2000*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
2001*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))
2002*da0073e9SAndroid Build Coastguard Worker        test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2003*da0073e9SAndroid Build Coastguard Worker        test(
2004*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)),
2005*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, B2, 5),),
2006*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
2007*da0073e9SAndroid Build Coastguard Worker        )
2008*da0073e9SAndroid Build Coastguard Worker
2009*da0073e9SAndroid Build Coastguard Worker    def test_stack(self):
2010*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
2011*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 5, 7
2012*da0073e9SAndroid Build Coastguard Worker
2013*da0073e9SAndroid Build Coastguard Worker        # Quick hack b/c vmap can't accept a list of tensors as an argument
2014*da0073e9SAndroid Build Coastguard Worker        def get_op(dim):
2015*da0073e9SAndroid Build Coastguard Worker            def op(*tensors):
2016*da0073e9SAndroid Build Coastguard Worker                return torch.stack(tensors, dim=dim)
2017*da0073e9SAndroid Build Coastguard Worker
2018*da0073e9SAndroid Build Coastguard Worker            return op
2019*da0073e9SAndroid Build Coastguard Worker
2020*da0073e9SAndroid Build Coastguard Worker        test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3)))
2021*da0073e9SAndroid Build Coastguard Worker        test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0))
2022*da0073e9SAndroid Build Coastguard Worker        test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
2023*da0073e9SAndroid Build Coastguard Worker        test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
2024*da0073e9SAndroid Build Coastguard Worker        test(
2025*da0073e9SAndroid Build Coastguard Worker            vmap(get_op(0), in_dims=(0, None)),
2026*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2), torch.rand(B0, 2)),
2027*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 0),
2028*da0073e9SAndroid Build Coastguard Worker        )
2029*da0073e9SAndroid Build Coastguard Worker        test(
2030*da0073e9SAndroid Build Coastguard Worker            vmap(get_op(0), in_dims=(0, 0)),
2031*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2), torch.rand(B0, B1, 2)),
2032*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 0),
2033*da0073e9SAndroid Build Coastguard Worker        )
2034*da0073e9SAndroid Build Coastguard Worker
2035*da0073e9SAndroid Build Coastguard Worker    def test_slice(self):
2036*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2037*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2038*da0073e9SAndroid Build Coastguard Worker        test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
2039*da0073e9SAndroid Build Coastguard Worker        test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)
2040*da0073e9SAndroid Build Coastguard Worker        test(
2041*da0073e9SAndroid Build Coastguard Worker            vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2
2042*da0073e9SAndroid Build Coastguard Worker        )
2043*da0073e9SAndroid Build Coastguard Worker        test(
2044*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
2045*da0073e9SAndroid Build Coastguard Worker            (torch.rand(3, 5, B0, B1, B2),),
2046*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
2047*da0073e9SAndroid Build Coastguard Worker        )
2048*da0073e9SAndroid Build Coastguard Worker
2049*da0073e9SAndroid Build Coastguard Worker    def test_squeeze(self):
2050*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2051*da0073e9SAndroid Build Coastguard Worker        op = torch.squeeze
2052*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 1, 11
2053*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0),))
2054*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 3, 5),))
2055*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(1, B0, 5),), in_dims=1)
2056*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 0, 1, 5, 1),))
2057*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 1, 1, 1, 1),))
2058*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), (torch.rand(B0, B1, 1),))
2059*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)
2060*da0073e9SAndroid Build Coastguard Worker
2061*da0073e9SAndroid Build Coastguard Worker    def test_sum_dim(self):
2062*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
2063*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 5, 7
2064*da0073e9SAndroid Build Coastguard Worker
2065*da0073e9SAndroid Build Coastguard Worker        # Single vmap, various in_dims / out_dims
2066*da0073e9SAndroid Build Coastguard Worker        test(lambda x: x.sum(()), [torch.randn([B0])])
2067*da0073e9SAndroid Build Coastguard Worker        test(lambda x: x.sum(()), [torch.randn([B0, 2])])
2068*da0073e9SAndroid Build Coastguard Worker        test(lambda x: x.sum(0), [torch.randn([B0])])
2069*da0073e9SAndroid Build Coastguard Worker        test(lambda x: x.sum(-1), [torch.randn([B0])])
2070*da0073e9SAndroid Build Coastguard Worker        test(lambda x: x.sum(0), [torch.randn([B0, 3])])
2071*da0073e9SAndroid Build Coastguard Worker        test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2)
2072*da0073e9SAndroid Build Coastguard Worker        test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
2073*da0073e9SAndroid Build Coastguard Worker
2074*da0073e9SAndroid Build Coastguard Worker        # Doubly nested vmap
2075*da0073e9SAndroid Build Coastguard Worker        test(vmap(lambda x: x.sum(())), [torch.randn([B0, B1])])
2076*da0073e9SAndroid Build Coastguard Worker        test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])])
2077*da0073e9SAndroid Build Coastguard Worker        test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])])
2078*da0073e9SAndroid Build Coastguard Worker        test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
2079*da0073e9SAndroid Build Coastguard Worker        test(
2080*da0073e9SAndroid Build Coastguard Worker            vmap(lambda x: x.sum(2), in_dims=2),
2081*da0073e9SAndroid Build Coastguard Worker            [torch.randn([2, 5, B0, B1, 3])],
2082*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
2083*da0073e9SAndroid Build Coastguard Worker            out_dims=2,
2084*da0073e9SAndroid Build Coastguard Worker        )
2085*da0073e9SAndroid Build Coastguard Worker
2086*da0073e9SAndroid Build Coastguard Worker    def test_reshape(self):
2087*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
2088*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2089*da0073e9SAndroid Build Coastguard Worker        op = torch.reshape
2090*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
2091*da0073e9SAndroid Build Coastguard Worker        test(
2092*da0073e9SAndroid Build Coastguard Worker            op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False
2093*da0073e9SAndroid Build Coastguard Worker        )
2094*da0073e9SAndroid Build Coastguard Worker        test(
2095*da0073e9SAndroid Build Coastguard Worker            vmap(lambda t: t.reshape([-1])),
2096*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B0, B1, 2, 5),),
2097*da0073e9SAndroid Build Coastguard Worker            check_view=True,
2098*da0073e9SAndroid Build Coastguard Worker        )
2099*da0073e9SAndroid Build Coastguard Worker        test(
2100*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1),
2101*da0073e9SAndroid Build Coastguard Worker            (torch.rand(3, B1, 2, B2, 5, B0),),
2102*da0073e9SAndroid Build Coastguard Worker            in_dims=5,
2103*da0073e9SAndroid Build Coastguard Worker            check_view=False,
2104*da0073e9SAndroid Build Coastguard Worker        )
2105*da0073e9SAndroid Build Coastguard Worker
2106*da0073e9SAndroid Build Coastguard Worker    def test_reshape_as(self):
2107*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
2108*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2109*da0073e9SAndroid Build Coastguard Worker        op = torch.Tensor.reshape_as
2110*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True)
2111*da0073e9SAndroid Build Coastguard Worker        test(
2112*da0073e9SAndroid Build Coastguard Worker            op,
2113*da0073e9SAndroid Build Coastguard Worker            (torch.rand(2 * 5), torch.rand(B0, 2, 5)),
2114*da0073e9SAndroid Build Coastguard Worker            in_dims=(None, 0),
2115*da0073e9SAndroid Build Coastguard Worker            check_view=True,
2116*da0073e9SAndroid Build Coastguard Worker        )
2117*da0073e9SAndroid Build Coastguard Worker        test(
2118*da0073e9SAndroid Build Coastguard Worker            op,
2119*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B0, 2 * 5), torch.rand(2, 5)),
2120*da0073e9SAndroid Build Coastguard Worker            in_dims=(0, None),
2121*da0073e9SAndroid Build Coastguard Worker            check_view=True,
2122*da0073e9SAndroid Build Coastguard Worker        )
2123*da0073e9SAndroid Build Coastguard Worker
2124*da0073e9SAndroid Build Coastguard Worker        test(
2125*da0073e9SAndroid Build Coastguard Worker            op,
2126*da0073e9SAndroid Build Coastguard Worker            (torch.rand(2, B0, 5), torch.rand(1, 1, 10)),
2127*da0073e9SAndroid Build Coastguard Worker            in_dims=(1, None),
2128*da0073e9SAndroid Build Coastguard Worker            check_view=False,
2129*da0073e9SAndroid Build Coastguard Worker        )
2130*da0073e9SAndroid Build Coastguard Worker
2131*da0073e9SAndroid Build Coastguard Worker        test(
2132*da0073e9SAndroid Build Coastguard Worker            vmap(op),
2133*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)),
2134*da0073e9SAndroid Build Coastguard Worker            check_view=True,
2135*da0073e9SAndroid Build Coastguard Worker        )
2136*da0073e9SAndroid Build Coastguard Worker        test(
2137*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)),
2138*da0073e9SAndroid Build Coastguard Worker            (torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
2139*da0073e9SAndroid Build Coastguard Worker            in_dims=(5, 0),
2140*da0073e9SAndroid Build Coastguard Worker            check_view=False,
2141*da0073e9SAndroid Build Coastguard Worker        )
2142*da0073e9SAndroid Build Coastguard Worker
2143*da0073e9SAndroid Build Coastguard Worker    def test_result_type(self):
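        # torch.result_type returns a dtype, but vmap only supports functions that
        # return tensors, so wrap the resulting dtype into a scalar tensor of that
        # dtype.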
2144*da0073e9SAndroid Build Coastguard Worker        def scalar_tensor_with_dtype(op):
2145*da0073e9SAndroid Build Coastguard Worker            def wrapped(*args, **kwargs):
2146*da0073e9SAndroid Build Coastguard Worker                dtype = op(*args, **kwargs)
2147*da0073e9SAndroid Build Coastguard Worker                return torch.ones([], dtype=dtype)
2148*da0073e9SAndroid Build Coastguard Worker
2149*da0073e9SAndroid Build Coastguard Worker            return wrapped
2150*da0073e9SAndroid Build Coastguard Worker
2151*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
2152*da0073e9SAndroid Build Coastguard Worker        op = scalar_tensor_with_dtype(torch.result_type)
2153*da0073e9SAndroid Build Coastguard Worker
2154*da0073e9SAndroid Build Coastguard Worker        B0 = 2
2155*da0073e9SAndroid Build Coastguard Worker
2156*da0073e9SAndroid Build Coastguard Worker        test(
2157*da0073e9SAndroid Build Coastguard Worker            op,
2158*da0073e9SAndroid Build Coastguard Worker            (torch.randn(B0), torch.randn(B0, dtype=torch.float64)),
2159*da0073e9SAndroid Build Coastguard Worker            check_propagates_grad=False,
2160*da0073e9SAndroid Build Coastguard Worker        )
2161*da0073e9SAndroid Build Coastguard Worker        test(
2162*da0073e9SAndroid Build Coastguard Worker            op,
2163*da0073e9SAndroid Build Coastguard Worker            (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)),
2164*da0073e9SAndroid Build Coastguard Worker            check_propagates_grad=False,
2165*da0073e9SAndroid Build Coastguard Worker        )
2166*da0073e9SAndroid Build Coastguard Worker
2167*da0073e9SAndroid Build Coastguard Worker        test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False)
2168*da0073e9SAndroid Build Coastguard Worker        test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False)
2169*da0073e9SAndroid Build Coastguard Worker
2170*da0073e9SAndroid Build Coastguard Worker        test(
2171*da0073e9SAndroid Build Coastguard Worker            lambda x: op(x, torch.tensor(1)),
2172*da0073e9SAndroid Build Coastguard Worker            (torch.randn(B0),),
2173*da0073e9SAndroid Build Coastguard Worker            check_propagates_grad=False,
2174*da0073e9SAndroid Build Coastguard Worker        )
2175*da0073e9SAndroid Build Coastguard Worker        test(
2176*da0073e9SAndroid Build Coastguard Worker            lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
2177*da0073e9SAndroid Build Coastguard Worker            (torch.randn(B0),),
2178*da0073e9SAndroid Build Coastguard Worker            check_propagates_grad=False,
2179*da0073e9SAndroid Build Coastguard Worker        )
2180*da0073e9SAndroid Build Coastguard Worker
2181*da0073e9SAndroid Build Coastguard Worker        test(
2182*da0073e9SAndroid Build Coastguard Worker            op,
2183*da0073e9SAndroid Build Coastguard Worker            (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)),
2184*da0073e9SAndroid Build Coastguard Worker            check_propagates_grad=False,
2185*da0073e9SAndroid Build Coastguard Worker        )
2186*da0073e9SAndroid Build Coastguard Worker        test(
2187*da0073e9SAndroid Build Coastguard Worker            op,
2188*da0073e9SAndroid Build Coastguard Worker            (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)),
2189*da0073e9SAndroid Build Coastguard Worker            check_propagates_grad=False,
2190*da0073e9SAndroid Build Coastguard Worker        )
2191*da0073e9SAndroid Build Coastguard Worker
2192*da0073e9SAndroid Build Coastguard Worker        test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False)
2193*da0073e9SAndroid Build Coastguard Worker        test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False)
2194*da0073e9SAndroid Build Coastguard Worker
2195*da0073e9SAndroid Build Coastguard Worker        test(
2196*da0073e9SAndroid Build Coastguard Worker            lambda x: op(x, torch.tensor(1)),
2197*da0073e9SAndroid Build Coastguard Worker            (torch.randn(B0, 2),),
2198*da0073e9SAndroid Build Coastguard Worker            check_propagates_grad=False,
2199*da0073e9SAndroid Build Coastguard Worker        )
2200*da0073e9SAndroid Build Coastguard Worker        test(
2201*da0073e9SAndroid Build Coastguard Worker            lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
2202*da0073e9SAndroid Build Coastguard Worker            (torch.randn(B0, 2),),
2203*da0073e9SAndroid Build Coastguard Worker            check_propagates_grad=False,
2204*da0073e9SAndroid Build Coastguard Worker        )
2205*da0073e9SAndroid Build Coastguard Worker
2206*da0073e9SAndroid Build Coastguard Worker        test(
2207*da0073e9SAndroid Build Coastguard Worker            op,
2208*da0073e9SAndroid Build Coastguard Worker            (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)),
2209*da0073e9SAndroid Build Coastguard Worker            check_propagates_grad=False,
2210*da0073e9SAndroid Build Coastguard Worker        )
2211*da0073e9SAndroid Build Coastguard Worker        test(
2212*da0073e9SAndroid Build Coastguard Worker            op,
2213*da0073e9SAndroid Build Coastguard Worker            (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)),
2214*da0073e9SAndroid Build Coastguard Worker            check_propagates_grad=False,
2215*da0073e9SAndroid Build Coastguard Worker        )
2216*da0073e9SAndroid Build Coastguard Worker
2217*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("too slow")
2218*da0073e9SAndroid Build Coastguard Worker    def test_tensor_split(self):
2219*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2220*da0073e9SAndroid Build Coastguard Worker        op = torch.tensor_split
2221*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2222*da0073e9SAndroid Build Coastguard Worker
2223*da0073e9SAndroid Build Coastguard Worker        # tests for torch.tensor_split(self, indices_or_sections: int, dim)
2224*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
2225*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
2226*da0073e9SAndroid Build Coastguard Worker        test(
2227*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None, None)),
2228*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 1023, B0, 5), 256, 0),
2229*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None),
2230*da0073e9SAndroid Build Coastguard Worker        )
2231*da0073e9SAndroid Build Coastguard Worker        test(
2232*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
2233*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 64, B2),),
2234*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
2235*da0073e9SAndroid Build Coastguard Worker        )
2236*da0073e9SAndroid Build Coastguard Worker
2237*da0073e9SAndroid Build Coastguard Worker        # tests for torch.tensor_split(self, indices_or_sections: List[int], dim)
2238*da0073e9SAndroid Build Coastguard Worker        test(
2239*da0073e9SAndroid Build Coastguard Worker            op,
2240*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1),
2241*da0073e9SAndroid Build Coastguard Worker            in_dims=(0, None, None),
2242*da0073e9SAndroid Build Coastguard Worker        )
2243*da0073e9SAndroid Build Coastguard Worker        test(
2244*da0073e9SAndroid Build Coastguard Worker            op,
2245*da0073e9SAndroid Build Coastguard Worker            (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1),
2246*da0073e9SAndroid Build Coastguard Worker            in_dims=(1, None, None),
2247*da0073e9SAndroid Build Coastguard Worker        )
2248*da0073e9SAndroid Build Coastguard Worker        test(
2249*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None, None)),
2250*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0),
2251*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None),
2252*da0073e9SAndroid Build Coastguard Worker        )
2253*da0073e9SAndroid Build Coastguard Worker        test(
2254*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)),
2255*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 64, B2),),
2256*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
2257*da0073e9SAndroid Build Coastguard Worker        )
2258*da0073e9SAndroid Build Coastguard Worker
2259*da0073e9SAndroid Build Coastguard Worker    def test_split(self):
2260*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2261*da0073e9SAndroid Build Coastguard Worker        op = torch.split
2262*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2263*da0073e9SAndroid Build Coastguard Worker
2264*da0073e9SAndroid Build Coastguard Worker        # tests for torch.split(self, split_size: int, dim)
2265*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None))
2266*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None))
2267*da0073e9SAndroid Build Coastguard Worker        test(
2268*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None, None)),
2269*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 1023, B0, 5), 256, 0),
2270*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None),
2271*da0073e9SAndroid Build Coastguard Worker        )
2272*da0073e9SAndroid Build Coastguard Worker        test(
2273*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
2274*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 64, B2),),
2275*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
2276*da0073e9SAndroid Build Coastguard Worker        )
2277*da0073e9SAndroid Build Coastguard Worker
2278*da0073e9SAndroid Build Coastguard Worker        # tests for torch.split(self, split_size: List[int], dim)
2279*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None))
2280*da0073e9SAndroid Build Coastguard Worker        test(
2281*da0073e9SAndroid Build Coastguard Worker            op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None)
2282*da0073e9SAndroid Build Coastguard Worker        )
2283*da0073e9SAndroid Build Coastguard Worker        test(
2284*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None, None)),
2285*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0),
2286*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None),
2287*da0073e9SAndroid Build Coastguard Worker        )
2288*da0073e9SAndroid Build Coastguard Worker        test(
2289*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)),
2290*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 64, B2),),
2291*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
2292*da0073e9SAndroid Build Coastguard Worker        )
2293*da0073e9SAndroid Build Coastguard Worker
2294*da0073e9SAndroid Build Coastguard Worker    def test_trace(self):
2295*da0073e9SAndroid Build Coastguard Worker        op = torch.trace
2296*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
2297*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2298*da0073e9SAndroid Build Coastguard Worker
2299*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 5),))
2300*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, B0, 5),), in_dims=1)
2301*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2302*da0073e9SAndroid Build Coastguard Worker        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
2303*da0073e9SAndroid Build Coastguard Worker
2304*da0073e9SAndroid Build Coastguard Worker    def test_transpose(self):
2305*da0073e9SAndroid Build Coastguard Worker        op = torch.transpose
2306*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2307*da0073e9SAndroid Build Coastguard Worker
2308*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2309*da0073e9SAndroid Build Coastguard Worker        test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),))
2310*da0073e9SAndroid Build Coastguard Worker        test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),))
2311*da0073e9SAndroid Build Coastguard Worker        test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),))
2312*da0073e9SAndroid Build Coastguard Worker        test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1)
2313*da0073e9SAndroid Build Coastguard Worker        test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2314*da0073e9SAndroid Build Coastguard Worker        test(
2315*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)),
2316*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 5, B2),),
2317*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
2318*da0073e9SAndroid Build Coastguard Worker        )
2319*da0073e9SAndroid Build Coastguard Worker
2320*da0073e9SAndroid Build Coastguard Worker        # Special case: scalar tensor
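        # (each per-example tensor is 0-dim here; transposing a 0-dim tensor is a
        # no-op that hands back the input itself, hence the identity check below)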
2321*da0073e9SAndroid Build Coastguard Worker        for dim1, dim2 in itertools.product([0, -1], [0, -1]):
2322*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(B0)
2323*da0073e9SAndroid Build Coastguard Worker            result = vmap(lambda x: op(x, dim1, dim2))(x)
2324*da0073e9SAndroid Build Coastguard Worker            self.assertIs(result, x)
2325*da0073e9SAndroid Build Coastguard Worker
2326*da0073e9SAndroid Build Coastguard Worker    def test_t(self):
2327*da0073e9SAndroid Build Coastguard Worker        op = torch.t
2328*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2329*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2330*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 5),))
2331*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, B0, 5),), in_dims=1)
2332*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2333*da0073e9SAndroid Build Coastguard Worker        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
2334*da0073e9SAndroid Build Coastguard Worker
2335*da0073e9SAndroid Build Coastguard Worker    def test_T_numpy(self):
2336*da0073e9SAndroid Build Coastguard Worker        def op(t):
2337*da0073e9SAndroid Build Coastguard Worker            return t.T
2338*da0073e9SAndroid Build Coastguard Worker
2339*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2340*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2341*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 3, 5),))
2342*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
2343*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
2344*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)
2345*da0073e9SAndroid Build Coastguard Worker        test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2)
2346*da0073e9SAndroid Build Coastguard Worker
2347*da0073e9SAndroid Build Coastguard Worker    def test_to(self):
2348*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_test
2349*da0073e9SAndroid Build Coastguard Worker        B0, B1 = 7, 11
2350*da0073e9SAndroid Build Coastguard Worker
2351*da0073e9SAndroid Build Coastguard Worker        test(lambda t: t.to("cpu"), (torch.rand(B0),))
2352*da0073e9SAndroid Build Coastguard Worker        test(lambda t: t.to(torch.double), (torch.rand(B0),))
2353*da0073e9SAndroid Build Coastguard Worker        test(
2354*da0073e9SAndroid Build Coastguard Worker            lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64))
2355*da0073e9SAndroid Build Coastguard Worker        )
2356*da0073e9SAndroid Build Coastguard Worker        test(
2357*da0073e9SAndroid Build Coastguard Worker            lambda t, o: t.to(o),
2358*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B0), torch.randn(B0, dtype=torch.float64)),
2359*da0073e9SAndroid Build Coastguard Worker            in_dims=(0, None),
2360*da0073e9SAndroid Build Coastguard Worker        )
2361*da0073e9SAndroid Build Coastguard Worker        test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),))
2362*da0073e9SAndroid Build Coastguard Worker
2363*da0073e9SAndroid Build Coastguard Worker        # also test some casting methods
2364*da0073e9SAndroid Build Coastguard Worker        test(lambda t: t.double(), (torch.rand(B0),))
2365*da0073e9SAndroid Build Coastguard Worker        test(lambda t: t.float(), (torch.rand(B0),))
2366*da0073e9SAndroid Build Coastguard Worker        test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False)
2367*da0073e9SAndroid Build Coastguard Worker        test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False)
2368*da0073e9SAndroid Build Coastguard Worker
2369*da0073e9SAndroid Build Coastguard Worker    def test_unfold(self):
2370*da0073e9SAndroid Build Coastguard Worker        op = torch.Tensor.unfold
2371*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2372*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 3, 2, 5
2373*da0073e9SAndroid Build Coastguard Worker
2374*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None))
2375*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None))
2376*da0073e9SAndroid Build Coastguard Worker        test(
2377*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None, None, None)),
2378*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 7, B0, 11), 1, 5, 1),
2379*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None, None),
2380*da0073e9SAndroid Build Coastguard Worker        )
2381*da0073e9SAndroid Build Coastguard Worker        test(
2382*da0073e9SAndroid Build Coastguard Worker            vmap(
2383*da0073e9SAndroid Build Coastguard Worker                vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)
2384*da0073e9SAndroid Build Coastguard Worker            ),
2385*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 7, B0, 11, B2), -1, 2, 4),
2386*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None, None, None),
2387*da0073e9SAndroid Build Coastguard Worker        )
2388*da0073e9SAndroid Build Coastguard Worker
2389*da0073e9SAndroid Build Coastguard Worker    def test_unbind(self):
2390*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2391*da0073e9SAndroid Build Coastguard Worker        op = torch.unbind
2392*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2393*da0073e9SAndroid Build Coastguard Worker
2394*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
2395*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2, 0),))
2396*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None))
2397*da0073e9SAndroid Build Coastguard Worker        test(
2398*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(0, None)),
2399*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 1023, B0, 5), 1),
2400*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, None),
2401*da0073e9SAndroid Build Coastguard Worker        )
2402*da0073e9SAndroid Build Coastguard Worker        test(
2403*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: op(t, dim=1), in_dims=2)),
2404*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, 2, B0, 32, B2),),
2405*da0073e9SAndroid Build Coastguard Worker            in_dims=2,
2406*da0073e9SAndroid Build Coastguard Worker        )
2407*da0073e9SAndroid Build Coastguard Worker
2408*da0073e9SAndroid Build Coastguard Worker    def test_view(self):
2409*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2410*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2411*da0073e9SAndroid Build Coastguard Worker        op = torch.Tensor.view
2412*da0073e9SAndroid Build Coastguard Worker
2413*da0073e9SAndroid Build Coastguard Worker        # We should error out if the view would produce an incorrect result
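        # (with in_dims=1 the per-example data is non-contiguous in the underlying
        # storage, so the flattened shape cannot be produced as a true view)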
2414*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
2415*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10])
2416*da0073e9SAndroid Build Coastguard Worker
2417*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None))
2418*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None))
2419*da0073e9SAndroid Build Coastguard Worker        test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),))
2420*da0073e9SAndroid Build Coastguard Worker        test(
2421*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(lambda t: t.reshape([-1])), in_dims=1),
2422*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B2, B0, B1, 3, 2, 5),),
2423*da0073e9SAndroid Build Coastguard Worker            in_dims=1,
2424*da0073e9SAndroid Build Coastguard Worker        )
2425*da0073e9SAndroid Build Coastguard Worker
2426*da0073e9SAndroid Build Coastguard Worker    def test_view_as(self):
2427*da0073e9SAndroid Build Coastguard Worker        test = self._vmap_view_test
2428*da0073e9SAndroid Build Coastguard Worker        B0, B1, B2 = 7, 11, 13
2429*da0073e9SAndroid Build Coastguard Worker        op = torch.Tensor.view_as
2430*da0073e9SAndroid Build Coastguard Worker
2431*da0073e9SAndroid Build Coastguard Worker        # We should error out if the view would produce an incorrect result
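        # (same contiguity caveat as in test_view above: the requested shape cannot
        # be produced as a true view of the batched input)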
2432*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
2433*da0073e9SAndroid Build Coastguard Worker            vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10))
2434*da0073e9SAndroid Build Coastguard Worker
2435*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)))
2436*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0))
2437*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None))
2438*da0073e9SAndroid Build Coastguard Worker
2439*da0073e9SAndroid Build Coastguard Worker        test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None))
2440*da0073e9SAndroid Build Coastguard Worker
2441*da0073e9SAndroid Build Coastguard Worker        test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)))
2442*da0073e9SAndroid Build Coastguard Worker        test(
2443*da0073e9SAndroid Build Coastguard Worker            vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)),
2444*da0073e9SAndroid Build Coastguard Worker            (torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)),
2445*da0073e9SAndroid Build Coastguard Worker            in_dims=(2, 0),
2446*da0073e9SAndroid Build Coastguard Worker        )
2447*da0073e9SAndroid Build Coastguard Worker
2448*da0073e9SAndroid Build Coastguard Worker    def test_no_random_op_support(self):
2449*da0073e9SAndroid Build Coastguard Worker        B0 = 2
2450*da0073e9SAndroid Build Coastguard Worker
2451*da0073e9SAndroid Build Coastguard Worker        captured = torch.rand(3)
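        # `captured` stands in for a tensor closed over by the vmapped function;
        # random ops applied to it inside vmap are expected to raise as well.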
2452*da0073e9SAndroid Build Coastguard Worker
2453*da0073e9SAndroid Build Coastguard Worker        random_ops = [
2454*da0073e9SAndroid Build Coastguard Worker            # out-of-place on BatchedTensor
2455*da0073e9SAndroid Build Coastguard Worker            (torch.bernoulli, (torch.rand(B0, 1),)),
2456*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.bernoulli(t, p=0.5), (torch.rand(B0, 1),)),
2457*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.multinomial(t, 2), (torch.rand(B0, 3),)),
2458*da0073e9SAndroid Build Coastguard Worker            (torch.normal, (torch.randn(B0, 1), torch.randn(B0, 1))),
2459*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.normal(t, 1.0), (torch.randn(B0, 1),)),
2460*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.normal(0.0, t), (torch.randn(B0, 1),)),
2461*da0073e9SAndroid Build Coastguard Worker            (torch.poisson, (torch.rand(B0, 1),)),
2462*da0073e9SAndroid Build Coastguard Worker            (torch.rand_like, (torch.rand(B0, 1),)),
2463*da0073e9SAndroid Build Coastguard Worker            (torch.randn_like, (torch.rand(B0, 1),)),
2464*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.randint_like(t, 2), (torch.rand(B0, 1),)),
2465*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.randint_like(t, 0, 2), (torch.rand(B0, 1),)),
2466*da0073e9SAndroid Build Coastguard Worker            # out-of-place on captured tensor
2467*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.bernoulli(captured), (torch.rand(B0),)),
2468*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.bernoulli(captured, p=0.5), (torch.rand(B0),)),
2469*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.multinomial(captured, 2), (torch.rand(B0),)),
2470*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.normal(captured, captured), (torch.randn(B0),)),
2471*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.normal(captured, 1.0), (torch.randn(B0),)),
2472*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.normal(0.0, captured), (torch.randn(B0),)),
2473*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.poisson(captured), (torch.rand(B0),)),
2474*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.rand_like(captured), (torch.rand(B0),)),
2475*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.randn_like(captured), (torch.rand(B0),)),
2476*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.randint_like(captured, 2), (torch.rand(B0),)),
2477*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.randint_like(captured, 0, 2), (torch.rand(B0),)),
2478*da0073e9SAndroid Build Coastguard Worker            # in-place on BatchedTensor
2479*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.bernoulli_(), (torch.randn(B0, 1),)),
2480*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.cauchy_(), (torch.randn(B0, 1),)),
2481*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.exponential_(), (torch.randn(B0, 1),)),
2482*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.geometric_(0.5), (torch.randn(B0, 1),)),
2483*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.log_normal_(), (torch.randn(B0, 1),)),
2484*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.normal_(), (torch.randn(B0, 1),)),
2485*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.random_(), (torch.randn(B0, 1),)),
2486*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.random_(0, 2), (torch.randn(B0, 1),)),
2487*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.random_(2), (torch.randn(B0, 1),)),
2488*da0073e9SAndroid Build Coastguard Worker            (lambda t: t.uniform_(), (torch.randn(B0, 1),)),
2489*da0073e9SAndroid Build Coastguard Worker            # in-place on captured tensor
2490*da0073e9SAndroid Build Coastguard Worker            (lambda t: captured.bernoulli_(), (torch.randn(B0),)),
2491*da0073e9SAndroid Build Coastguard Worker            (lambda t: captured.cauchy_(), (torch.randn(B0),)),
2492*da0073e9SAndroid Build Coastguard Worker            (lambda t: captured.exponential_(), (torch.randn(B0),)),
2493*da0073e9SAndroid Build Coastguard Worker            (lambda t: captured.geometric_(0.5), (torch.randn(B0),)),
2494*da0073e9SAndroid Build Coastguard Worker            (lambda t: captured.log_normal_(), (torch.randn(B0),)),
2495*da0073e9SAndroid Build Coastguard Worker            (lambda t: captured.normal_(), (torch.randn(B0),)),
2496*da0073e9SAndroid Build Coastguard Worker            (lambda t: captured.random_(), (torch.randn(B0),)),
2497*da0073e9SAndroid Build Coastguard Worker            (lambda t: captured.random_(0, 2), (torch.randn(B0),)),
2498*da0073e9SAndroid Build Coastguard Worker            (lambda t: captured.random_(2), (torch.randn(B0),)),
2499*da0073e9SAndroid Build Coastguard Worker            (lambda t: captured.uniform_(), (torch.randn(B0),)),
2500*da0073e9SAndroid Build Coastguard Worker            # factory functions
2501*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.rand(1), (torch.randn(B0),)),
2502*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.randn(1), (torch.randn(B0),)),
2503*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.randint(5, [1]), (torch.randn(B0),)),
2504*da0073e9SAndroid Build Coastguard Worker            (lambda t: torch.randperm(5), (torch.randn(B0),)),
2505*da0073e9SAndroid Build Coastguard Worker        ]
2506*da0073e9SAndroid Build Coastguard Worker        for op, args in random_ops:
2507*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
2508*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "vmap: We do not yet support calling random operations"
2509*da0073e9SAndroid Build Coastguard Worker            ):
2510*da0073e9SAndroid Build Coastguard Worker                vmap(op)(*args)
2511*da0073e9SAndroid Build Coastguard Worker
2512*da0073e9SAndroid Build Coastguard Worker
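# Helpers for the batched-gradient tests below. construct_v draws a batch of
# random cotangent vectors matching an output's shape, dtype, and device.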
2513*da0073e9SAndroid Build Coastguard Workerdef construct_v(output, batch_size):
2514*da0073e9SAndroid Build Coastguard Worker    return torch.randn(
2515*da0073e9SAndroid Build Coastguard Worker        batch_size, *output.shape, dtype=output.dtype, device=output.device
2516*da0073e9SAndroid Build Coastguard Worker    )
2517*da0073e9SAndroid Build Coastguard Worker
2518*da0073e9SAndroid Build Coastguard Worker
2519*da0073e9SAndroid Build Coastguard Workerdef as_tuple(x):
2520*da0073e9SAndroid Build Coastguard Worker    if isinstance(x, tuple):
2521*da0073e9SAndroid Build Coastguard Worker        return x
2522*da0073e9SAndroid Build Coastguard Worker    elif isinstance(x, list):
2523*da0073e9SAndroid Build Coastguard Worker        return tuple(x)
2524*da0073e9SAndroid Build Coastguard Worker    else:
2525*da0073e9SAndroid Build Coastguard Worker        return (x,)
2526*da0073e9SAndroid Build Coastguard Worker
2527*da0073e9SAndroid Build Coastguard Worker
2528*da0073e9SAndroid Build Coastguard Workerdef differentiable(args):
2529*da0073e9SAndroid Build Coastguard Worker    return tuple(
2530*da0073e9SAndroid Build Coastguard Worker        arg
2531*da0073e9SAndroid Build Coastguard Worker        for arg in as_tuple(args)
2532*da0073e9SAndroid Build Coastguard Worker        if isinstance(arg, torch.Tensor) and arg.requires_grad
2533*da0073e9SAndroid Build Coastguard Worker    )
2534*da0073e9SAndroid Build Coastguard Worker
2535*da0073e9SAndroid Build Coastguard Worker
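# Like torch.rand, but clamps values to a minimum of 0.1 so results are bounded
# away from zero (keeps gradients of ops such as div and log well-behaved).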
2536*da0073e9SAndroid Build Coastguard Workerdef _get_rand_no_zeros(*args, **kwargs):
2537*da0073e9SAndroid Build Coastguard Worker    requires_grad = kwargs.get("requires_grad", False)
2538*da0073e9SAndroid Build Coastguard Worker    kwargs_without_requires_grad = kwargs.copy()
2539*da0073e9SAndroid Build Coastguard Worker    kwargs_without_requires_grad["requires_grad"] = False
2540*da0073e9SAndroid Build Coastguard Worker    result = torch.rand(*args, **kwargs_without_requires_grad)
2541*da0073e9SAndroid Build Coastguard Worker    return result.clamp_min_(0.1).requires_grad_(requires_grad)
2542*da0073e9SAndroid Build Coastguard Worker
2543*da0073e9SAndroid Build Coastguard Worker
2544*da0073e9SAndroid Build Coastguard Workerclass TestVmapBatchedGradientLegacy(Namespace.TestVmapBaseLegacy):
2545*da0073e9SAndroid Build Coastguard Worker    def _vmap_test(self, *args, **kwargs):
2546*da0073e9SAndroid Build Coastguard Worker        return _vmap_test(self, *args, **kwargs)
2547*da0073e9SAndroid Build Coastguard Worker
2548*da0073e9SAndroid Build Coastguard Worker    # Tests batched gradient computation of outputs = op(*args, **kwargs)
2549*da0073e9SAndroid Build Coastguard Worker    # by comparing it to a sequential map+stack fallback.
2550*da0073e9SAndroid Build Coastguard Worker    #
2551*da0073e9SAndroid Build Coastguard Worker    # output_process_fn: a function that maps the outputs to the part
2552*da0073e9SAndroid Build Coastguard Worker    #       that should be differentiated.
2553*da0073e9SAndroid Build Coastguard Worker    # batch_size: the batch dim size for the batched grad
2554*da0073e9SAndroid Build Coastguard Worker    def _batched_grad_test(
2555*da0073e9SAndroid Build Coastguard Worker        self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3
2556*da0073e9SAndroid Build Coastguard Worker    ):
2557*da0073e9SAndroid Build Coastguard Worker        if kwargs is None:
2558*da0073e9SAndroid Build Coastguard Worker            kwargs = {}
2559*da0073e9SAndroid Build Coastguard Worker        outputs = op(*args, **kwargs)
2560*da0073e9SAndroid Build Coastguard Worker        outputs = differentiable(output_process_fn(outputs))
2561*da0073e9SAndroid Build Coastguard Worker        batched_vectors = tuple(construct_v(out, batch_size) for out in outputs)
2562*da0073e9SAndroid Build Coastguard Worker
2563*da0073e9SAndroid Build Coastguard Worker        def vector_jacobian_product(*vectors):
2564*da0073e9SAndroid Build Coastguard Worker            return torch.autograd.grad(
2565*da0073e9SAndroid Build Coastguard Worker                outputs, differentiable(args), vectors, retain_graph=True
2566*da0073e9SAndroid Build Coastguard Worker            )
2567*da0073e9SAndroid Build Coastguard Worker
2568*da0073e9SAndroid Build Coastguard Worker        self._vmap_test(
2569*da0073e9SAndroid Build Coastguard Worker            vector_jacobian_product, batched_vectors, check_propagates_grad=False
2570*da0073e9SAndroid Build Coastguard Worker        )
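
    # A minimal illustrative sketch (not executed; the op, shapes, and batch size
    # below are arbitrary assumptions) of the batched VJP pattern this helper
    # exercises:
    #
    #   x = torch.randn(3, requires_grad=True)
    #   y = x.sin()                                    # outputs = op(*args)
    #   vs = torch.randn(5, 3)                         # batch of cotangent vectors
    #   vjp = lambda v: torch.autograd.grad(y, x, v, retain_graph=True)[0]
    #   batched = vmap(vjp)(vs)                        # all five VJPs in one call
    #   expected = torch.stack([vjp(v) for v in vs])   # sequential map+stack fallback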
2571*da0073e9SAndroid Build Coastguard Worker
2572*da0073e9SAndroid Build Coastguard Worker    # Tests batched second grad computation of outputs = op(*args, **kwargs)
2573*da0073e9SAndroid Build Coastguard Worker    # by comparing it to a sequential map+stack fallback.
2574*da0073e9SAndroid Build Coastguard Worker    #
2575*da0073e9SAndroid Build Coastguard Worker    # output_process_fn: a function that maps the outputs to the part
2576*da0073e9SAndroid Build Coastguard Worker    #       that should be differentiated.
2577*da0073e9SAndroid Build Coastguard Worker    # batch_size: the batch dim size for the batched grad
2578*da0073e9SAndroid Build Coastguard Worker    #
2579*da0073e9SAndroid Build Coastguard Worker    # NB: we only test computing batched gradients in the second gradient
2580*da0073e9SAndroid Build Coastguard Worker    # computation. One specific use case that does this is computing the Hessian
2581*da0073e9SAndroid Build Coastguard Worker    # of a scalar-valued function, which is useful in, e.g., Bayesian logistic
2582*da0073e9SAndroid Build Coastguard Worker    # regression.
2583*da0073e9SAndroid Build Coastguard Worker    # In the future it might be useful to add a test that computes batched first
2584*da0073e9SAndroid Build Coastguard Worker    # gradients and then uses those to compute batched second gradients.
2585*da0073e9SAndroid Build Coastguard Worker    def _batched_grad_grad_test(
2586*da0073e9SAndroid Build Coastguard Worker        self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3
2587*da0073e9SAndroid Build Coastguard Worker    ):
2588*da0073e9SAndroid Build Coastguard Worker        if kwargs is None:
2589*da0073e9SAndroid Build Coastguard Worker            kwargs = {}
2590*da0073e9SAndroid Build Coastguard Worker        outputs = op(*args, **kwargs)
2591*da0073e9SAndroid Build Coastguard Worker        outputs = differentiable(output_process_fn(outputs))
2592*da0073e9SAndroid Build Coastguard Worker        ones = tuple(torch.ones_like(out) for out in outputs)
2593*da0073e9SAndroid Build Coastguard Worker        # Same thing as summing together all of the outputs and calling .backward()
2594*da0073e9SAndroid Build Coastguard Worker        first_grads = torch.autograd.grad(
2595*da0073e9SAndroid Build Coastguard Worker            outputs, differentiable(args), ones, create_graph=True
2596*da0073e9SAndroid Build Coastguard Worker        )
2597*da0073e9SAndroid Build Coastguard Worker        first_grads = differentiable(first_grads)
2598*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(
2599*da0073e9SAndroid Build Coastguard Worker            len(first_grads), 0, "None of the first grads depend on the input!"
2600*da0073e9SAndroid Build Coastguard Worker        )
2601*da0073e9SAndroid Build Coastguard Worker
2602*da0073e9SAndroid Build Coastguard Worker        batched_vectors = tuple(construct_v(grad, batch_size) for grad in first_grads)
2603*da0073e9SAndroid Build Coastguard Worker
2604*da0073e9SAndroid Build Coastguard Worker        def vector_hessian_product(*vectors):
2605*da0073e9SAndroid Build Coastguard Worker            outputs = torch.autograd.grad(
2606*da0073e9SAndroid Build Coastguard Worker                first_grads,
2607*da0073e9SAndroid Build Coastguard Worker                differentiable(args),
2608*da0073e9SAndroid Build Coastguard Worker                vectors,
2609*da0073e9SAndroid Build Coastguard Worker                retain_graph=True,
2610*da0073e9SAndroid Build Coastguard Worker                allow_unused=True,
2611*da0073e9SAndroid Build Coastguard Worker            )
2612*da0073e9SAndroid Build Coastguard Worker            outputs = tuple(out for out in outputs if out is not None)
2613*da0073e9SAndroid Build Coastguard Worker            assert len(outputs) > 0
2614*da0073e9SAndroid Build Coastguard Worker            return outputs
2615*da0073e9SAndroid Build Coastguard Worker
2616*da0073e9SAndroid Build Coastguard Worker        self._vmap_test(
2617*da0073e9SAndroid Build Coastguard Worker            vector_hessian_product, batched_vectors, check_propagates_grad=False
2618*da0073e9SAndroid Build Coastguard Worker        )
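
    # A minimal illustrative sketch (not executed; f and the input size below are
    # arbitrary assumptions) of the Hessian use case mentioned above, built from
    # batched Hessian-vector products:
    #
    #   x = torch.randn(3, requires_grad=True)
    #   out = (x.sin() ** 2).sum()                     # scalar-valued f(x)
    #   (g,) = torch.autograd.grad(out, x, create_graph=True)
    #   hvp = lambda v: torch.autograd.grad(g, x, v, retain_graph=True)[0]
    #   hessian = vmap(hvp)(torch.eye(3))              # row i is the i-th row of H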
2619*da0073e9SAndroid Build Coastguard Worker
2620*da0073e9SAndroid Build Coastguard Worker    def _test_arithmetic(self, op, device, test_grad_grad=True):
2621*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, requires_grad=True, device=device)
2622*da0073e9SAndroid Build Coastguard Worker        y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
2623*da0073e9SAndroid Build Coastguard Worker        scalar = 3.14
2624*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(op, (x, y))
2625*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(op, (scalar, y))
2626*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(op, (x, scalar))
2627*da0073e9SAndroid Build Coastguard Worker
2628*da0073e9SAndroid Build Coastguard Worker        if test_grad_grad:
2629*da0073e9SAndroid Build Coastguard Worker            self._batched_grad_grad_test(op, (x, y))
2630*da0073e9SAndroid Build Coastguard Worker
2631*da0073e9SAndroid Build Coastguard Worker    def test_add(self, device):
2632*da0073e9SAndroid Build Coastguard Worker        self._test_arithmetic(torch.add, device, test_grad_grad=False)
2633*da0073e9SAndroid Build Coastguard Worker        self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False)
2634*da0073e9SAndroid Build Coastguard Worker
2635*da0073e9SAndroid Build Coastguard Worker    def test_sub(self, device):
2636*da0073e9SAndroid Build Coastguard Worker        self._test_arithmetic(torch.sub, device, test_grad_grad=False)
2637*da0073e9SAndroid Build Coastguard Worker        self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False)
2638*da0073e9SAndroid Build Coastguard Worker
2639*da0073e9SAndroid Build Coastguard Worker    def test_mul(self, device):
2640*da0073e9SAndroid Build Coastguard Worker        self._test_arithmetic(torch.mul, device)
2641*da0073e9SAndroid Build Coastguard Worker        self._test_arithmetic(lambda x, y: x * y, device)
2642*da0073e9SAndroid Build Coastguard Worker
2643*da0073e9SAndroid Build Coastguard Worker    def test_div(self, device):
2644*da0073e9SAndroid Build Coastguard Worker        self._test_arithmetic(torch.div, device)
2645*da0073e9SAndroid Build Coastguard Worker        self._test_arithmetic(lambda x, y: x / y, device)
2646*da0073e9SAndroid Build Coastguard Worker
2647*da0073e9SAndroid Build Coastguard Worker    @allowVmapFallbackUsage
2648*da0073e9SAndroid Build Coastguard Worker    def test_binary_cross_entropy(self, device):
2649*da0073e9SAndroid Build Coastguard Worker        x = torch.sigmoid(torch.randn(3, 2, device=device, requires_grad=True))
2650*da0073e9SAndroid Build Coastguard Worker        target = torch.rand(3, 2, device=device)
2651*da0073e9SAndroid Build Coastguard Worker
2652*da0073e9SAndroid Build Coastguard Worker        op = functools.partial(F.binary_cross_entropy, target=target)
2653*da0073e9SAndroid Build Coastguard Worker
2654*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(op, (x,), {})
2655*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_grad_test(op, (x,), {})
2656*da0073e9SAndroid Build Coastguard Worker
2657*da0073e9SAndroid Build Coastguard Worker    def test_expand(self, device):
2658*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, device=device, requires_grad=True)
2659*da0073e9SAndroid Build Coastguard Worker
2660*da0073e9SAndroid Build Coastguard Worker        def op(x):
2661*da0073e9SAndroid Build Coastguard Worker            return x.expand(5, 5, 2, 3)
2662*da0073e9SAndroid Build Coastguard Worker
2663*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(op, (x,))
2664*da0073e9SAndroid Build Coastguard Worker
2665*da0073e9SAndroid Build Coastguard Worker    @allowVmapFallbackUsage
2666*da0073e9SAndroid Build Coastguard Worker    def test_index(self, device):
2667*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, requires_grad=True, device=device)
2668*da0073e9SAndroid Build Coastguard Worker        index = torch.tensor([[0, 0], [1, 1]], device=device)
2669*da0073e9SAndroid Build Coastguard Worker
2670*da0073e9SAndroid Build Coastguard Worker        def op(x):
2671*da0073e9SAndroid Build Coastguard Worker            y = x * x
2672*da0073e9SAndroid Build Coastguard Worker            return y[index]
2673*da0073e9SAndroid Build Coastguard Worker
2674*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(op, (x,))
2675*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_grad_test(op, (x,))
2676*da0073e9SAndroid Build Coastguard Worker
2677*da0073e9SAndroid Build Coastguard Worker    def test_lgamma(self, device):
2678*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, requires_grad=True, device=device)
2679*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(Tensor.lgamma, (x,))
2680*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_grad_test(Tensor.lgamma, (x,))
2681*da0073e9SAndroid Build Coastguard Worker
2682*da0073e9SAndroid Build Coastguard Worker    def test_log(self, device):
2683*da0073e9SAndroid Build Coastguard Worker        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
2684*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(torch.log, (x,))
2685*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_grad_test(torch.log, (x,))
2686*da0073e9SAndroid Build Coastguard Worker
2687*da0073e9SAndroid Build Coastguard Worker    def test_logsumexp(self, device):
2688*da0073e9SAndroid Build Coastguard Worker        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
2689*da0073e9SAndroid Build Coastguard Worker
2690*da0073e9SAndroid Build Coastguard Worker        def op(x):
2691*da0073e9SAndroid Build Coastguard Worker            return torch.logsumexp(x, -1)
2692*da0073e9SAndroid Build Coastguard Worker
2693*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(op, (x,))
2694*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_grad_test(op, (x,))
2695*da0073e9SAndroid Build Coastguard Worker
2696*da0073e9SAndroid Build Coastguard Worker    def test_log1p(self, device):
2697*da0073e9SAndroid Build Coastguard Worker        x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
2698*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(torch.log1p, (x,))
2699*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_grad_test(torch.log1p, (x,))
2700*da0073e9SAndroid Build Coastguard Worker
2701*da0073e9SAndroid Build Coastguard Worker    @allowVmapFallbackUsage
2702*da0073e9SAndroid Build Coastguard Worker    def test_max(self, device):
2703*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, requires_grad=True, device=device)
2704*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(torch.max, (x,))
2705*da0073e9SAndroid Build Coastguard Worker
2706*da0073e9SAndroid Build Coastguard Worker    @allowVmapFallbackUsage
2707*da0073e9SAndroid Build Coastguard Worker    def test_median(self, device):
2708*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, requires_grad=True, device=device)
2709*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(torch.median, (x,))
2710*da0073e9SAndroid Build Coastguard Worker
2711*da0073e9SAndroid Build Coastguard Worker    @allowVmapFallbackUsage
2712*da0073e9SAndroid Build Coastguard Worker    def test_min(self, device):
2713*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, requires_grad=True, device=device)
2714*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(torch.min, (x,))
2715*da0073e9SAndroid Build Coastguard Worker
2716*da0073e9SAndroid Build Coastguard Worker    def test_permute(self, device):
2717*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 5, requires_grad=True, device=device)
2718*da0073e9SAndroid Build Coastguard Worker
2719*da0073e9SAndroid Build Coastguard Worker        def op(x):
2720*da0073e9SAndroid Build Coastguard Worker            return x.permute(2, 0, 1)
2721*da0073e9SAndroid Build Coastguard Worker
2722*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(op, (x,))
2723*da0073e9SAndroid Build Coastguard Worker
2724*da0073e9SAndroid Build Coastguard Worker    def test_reshape(self, device):
2725*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 5, requires_grad=True, device=device)
2726*da0073e9SAndroid Build Coastguard Worker
2727*da0073e9SAndroid Build Coastguard Worker        def op(x):
2728*da0073e9SAndroid Build Coastguard Worker            return x.reshape([2 * 3, 5])
2729*da0073e9SAndroid Build Coastguard Worker
2730*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(op, (x,))
2731*da0073e9SAndroid Build Coastguard Worker
2732*da0073e9SAndroid Build Coastguard Worker    def test_sigmoid(self, device):
2733*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, requires_grad=True, device=device)
2734*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(Tensor.sigmoid, (x,))
2735*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_grad_test(Tensor.sigmoid, (x,))
2736*da0073e9SAndroid Build Coastguard Worker
2737*da0073e9SAndroid Build Coastguard Worker    def test_stack(self, device):
2738*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, device=device, requires_grad=True)
2739*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 3, device=device, requires_grad=True)
2740*da0073e9SAndroid Build Coastguard Worker
2741*da0073e9SAndroid Build Coastguard Worker        def op(x, y):
2742*da0073e9SAndroid Build Coastguard Worker            return torch.stack([x, y])
2743*da0073e9SAndroid Build Coastguard Worker
2744*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(op, (x, y))
2745*da0073e9SAndroid Build Coastguard Worker
2746*da0073e9SAndroid Build Coastguard Worker    def test_select(self, device):
2747*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, device=device, requires_grad=True)
2748*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(lambda x: x[1], (x,))
2749*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(lambda x: x.select(1, 2), (x,))
2750*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(lambda x: x.select(-1, 0), (x,))
2751*da0073e9SAndroid Build Coastguard Worker
2752*da0073e9SAndroid Build Coastguard Worker    def test_slice(self, device):
2753*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 5, device=device, requires_grad=True)
2754*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(lambda x: x[0:1], (x,))
2755*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(lambda x: x[:, 1:3], (x,))
2756*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(lambda x: x[..., 1:3], (x,))
2757*da0073e9SAndroid Build Coastguard Worker
2758*da0073e9SAndroid Build Coastguard Worker    def test_trace(self, device):
2759*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, device=device, requires_grad=True)
2760*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(Tensor.trace, (x,))
2761*da0073e9SAndroid Build Coastguard Worker
2762*da0073e9SAndroid Build Coastguard Worker    def test_threshold(self, device):
2763*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, device=device, requires_grad=True)
2764*da0073e9SAndroid Build Coastguard Worker        self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))
2765*da0073e9SAndroid Build Coastguard Worker
    @allowVmapFallbackUsage
    def test_inplace_on_view(self, device):
        leaf = torch.randn(4, 5, requires_grad=True)

        def func(leaf):
            # Make sure the function is non-trivially twice differentiable
            base = leaf * leaf
            view = base[0]
            view.cos_()
            return view

        self._batched_grad_test(func, (leaf,), {})
        self._batched_grad_grad_test(func, (leaf,), {})

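    # Same as above, but the in-place op is applied through a chain of view
    # ops (transpose, indexing, diagonal, strided slice), so the update has
    # to be propagated back through several views.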
    @allowVmapFallbackUsage
    def test_inplace_manyview(self, device):
        leaf = torch.randn(4, 4, 5, requires_grad=True)

        def func(leaf):
            # Make sure the function is non-trivially twice differentiable
            base = leaf * leaf
            view = base.transpose(0, 2)
            view = view[1]
            view = view.diagonal()
            view = view[::2]
            view.cos_()
            return view

        self._batched_grad_test(func, (leaf,), {})
        self._batched_grad_grad_test(func, (leaf,), {})

    def test_diagonal(self, device):
        x = torch.randn(4, 5, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,))

        x = torch.randn(3, 4, 5, device=device, requires_grad=True)
        self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,))

    @allowVmapFallbackUsage
    def test_unrelated_output(self, device):
        B0 = 3
        x = torch.randn([], requires_grad=True)
        y = torch.randn([], requires_grad=True)
        gy = torch.randn(B0, requires_grad=True)

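        # y does not depend on x, so autograd.grad returns None for the
        # gradient (allow_unused=True); vjp maps that to zeros, and vmapping
        # it should therefore produce a batch of B0 zero gradients.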
        def vjp(v):
            (res,) = torch.autograd.grad(y, x, v, allow_unused=True)
            return torch.zeros_like(x) if res is None else res

        result = vmap(vjp)(gy)
        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))

    @allowVmapFallbackUsage
    def test_unrelated_output_multiple_grad(self, device):
        B0 = 3
        x = torch.randn([], requires_grad=True)
        y = torch.randn([], requires_grad=True)
        gy = torch.randn(B0, requires_grad=True)

        def vjp(v):
            (res,) = torch.autograd.grad(y, x, v, allow_unused=True)
            return torch.zeros_like(x) if res is None else res

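        # Call vjp once outside of vmap first to check that an earlier
        # unbatched call does not interfere with the subsequent batched one.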
        _ = vjp(gy[0])
        result = vmap(vjp)(gy)
        self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))


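# Generate per-device variants of the batched-gradient tests above
# (e.g. a CPU and, when available, a CUDA instantiation).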
instantiate_device_type_tests(TestVmapBatchedGradientLegacy, globals(), None)

if __name__ == "__main__":
    run_tests()