# mypy: allow-untyped-defs
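"""Tests for ``torch.distributed.nn.RemoteModule``.

Covers construction, ``forward``/``forward_async``, remote parameter access, pickling
over RPC, unsupported ``nn.Module`` methods, and CUDA device placement.
"""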

import enum
from typing import Tuple

import torch
import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils as dist_utils
from torch import Tensor, nn
from torch._jit_internal import Future
from torch.distributed.nn import RemoteModule
from torch.distributed.nn.api.remote_module import (
    _REMOTE_MODULE_PICKLED_ATTRIBUTES,
    _RemoteModule,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)


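# The single parameter installed by ``MyModule``; tests compare remote parameters against it.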
_PARAM_VAL = torch.nn.Parameter(torch.ones(1))


# RPC handler for querying the device on the destination worker.
def remote_device(module_rref):
    for param in module_rref.local_value().parameters():
        return param.device


# RPC handler for querying __dict__ on the destination worker.
def remote_module_attributes(remote_module):
    return remote_module.__dict__


# RPC handler for running forward on the destination worker.
def remote_forward(remote_module, args):
    return remote_module.forward(*args)


# RPC handler for running forward_async on the destination worker.
def remote_forward_async(remote_module, args):
    # Since a Future cannot be pickled and sent over the RPC layer,
    # we have to wait here and behave just like ``forward``.
    return remote_module.forward_async(*args).wait()


# RPC handler for getting training mode on the destination worker.
def get_remote_training_arg(module_rref):
    return module_rref.local_value().training


class ModuleCreationMode(enum.Enum):
    MODULE_CTOR_WITH_INTERFACE = "module_ctor_with_interface"
    MODULE_CTOR = "module_ctor"


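# TorchScript interfaces describing the remote module's ``forward`` signature;
# ``RemoteMyModuleInterface`` additionally exposes ``forward_async`` for the scripted path.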
@torch.jit.interface
class MyModuleInterface:
    def forward(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Tuple[str, int, Tensor]:
        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
        pass


@torch.jit.interface
class RemoteMyModuleInterface:
    def forward(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Tuple[str, int, Tensor]:
        # pyre-ignore[7]: Pyre and torch.jit.interface don't mix well
        pass

    def forward_async(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Future[Tuple[str, int, Tensor]]:
        pass


class MyModule(nn.Module):
    def __init__(self, first_arg, first_kwarg=-1):
        super().__init__()
        self.param1 = _PARAM_VAL

    def forward(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Tuple[str, int, Tensor]:
        return word, number, tensor


class BadModule:
    def __init__(self, first_arg, first_kwarg=-1):
        pass


def create_scripted_module(first_arg, first_kwarg=-1):
    module = MyModule(first_arg, first_kwarg=first_kwarg)
    scripted_module = torch.jit.script(module)
    return scripted_module


# Common utils for both CPU and CUDA test suites
class CommonRemoteModuleTest(RpcAgentTestFixture):
    @property
    def world_size(self):  # Override setting in RpcAgentTestFixture
        return 2

    @staticmethod
    def _create_remote_module_iter(remote_device, modes=None):
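        """Yield a RemoteModule created on ``remote_device`` for each requested creation mode."""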
        if modes is None:
            modes = ModuleCreationMode.__members__.values()

        args = (1,)
        kwargs = dict(first_kwarg=2)

        if ModuleCreationMode.MODULE_CTOR in modes:
            remote_module = RemoteModule(remote_device, MyModule, args, kwargs)
            yield remote_module

        if ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE in modes:
            remote_module = _RemoteModule(
                remote_device,
                create_scripted_module,
                args,
                kwargs,
                _module_interface_cls=MyModuleInterface,
            )
            scripted_remote_module = torch.jit.script(remote_module)
            yield scripted_remote_module


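# Tests for the Python RemoteModule API: construction errors, forward/forward_async,
# remote parameter access, train/eval, unsupported methods, and pickling restrictions.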
class RemoteModuleTest(CommonRemoteModuleTest):
    @dist_utils.dist_init
    def test_bad_module(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        remote_device = f"{dst_worker_name}/cpu"
        args = (1,)
        kwargs = dict(first_kwarg=2)

        with self.assertRaisesRegex(
            ValueError,
            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
        ):
            RemoteModule(remote_device, BadModule, args, kwargs).forward()

        with self.assertRaisesRegex(
            ValueError,
            r"Expect `module_cls\(\*args, \*\*kwargs\)` returns an instance of <class nn.Module>,",
        ):
            RemoteModule(remote_device, BadModule, args, kwargs).forward()

    @dist_utils.dist_init
    def test_forward_async(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        args = (torch.ones(1), 2, "3")
        for remote_module in self._create_remote_module_iter(dst_worker_name):
            ret_fut = remote_module.forward_async(*args)
            ret = ret_fut.wait()
            self.assertEqual(ret, tuple(reversed(args)))

    @dist_utils.dist_init
    def test_forward_async_script(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        scripted_remote_module = next(
            self._create_remote_module_iter(
                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
            )
        )

        @torch.jit.script
        def run_forward_async(scripted_remote_module: RemoteMyModuleInterface):
            ret_fut = scripted_remote_module.forward_async(torch.ones(1), 2, "3")
            ret = ret_fut.wait()
            return ret

        ret = run_forward_async(scripted_remote_module)

        self.assertEqual(ret, ("3", 2, torch.ones(1)))

    @dist_utils.dist_init
    def test_forward_sync(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        args = (torch.ones(1), 2, "3")
        for remote_module in self._create_remote_module_iter(dst_worker_name):
            ret = remote_module.forward(*args)
            self.assertEqual(ret, tuple(reversed(args)))

    @dist_utils.dist_init
    def test_forward_sync_script(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        scripted_remote_module = next(
            self._create_remote_module_iter(
                dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
            )
        )

        @torch.jit.script
        def run_forward(scripted_remote_module: MyModuleInterface):
            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
            return ret

        ret = run_forward(scripted_remote_module)

        self.assertEqual(ret, ("3", 2, torch.ones(1)))

    @dist_utils.dist_init
    def test_forward_with_kwargs(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        args = (torch.ones(1), 2)
        kwargs = dict(word="3")
        # Only test Python nn.Module, because script module methods don't support taking kwargs.
        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            ret_fut = remote_module.forward_async(*args, **kwargs)
            ret = ret_fut.wait()
            self.assertEqual(ret, tuple(reversed(args + ("3",))))

            ret = remote_module.forward(*args, **kwargs)
            self.assertEqual(ret, tuple(reversed(args + ("3",))))

    @dist_utils.dist_init
    def test_remote_parameters(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        # Only test Python nn.Module, because script module methods don't support ``remote_parameters``.
        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            param_rrefs = remote_module.remote_parameters()
            self.assertEqual(len(param_rrefs), 1)
            self.assertTrue(torch.equal(param_rrefs[0].to_here(), _PARAM_VAL))

    @dist_utils.dist_init
    def test_get_module_rref(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        # Only test Python nn.Module, because script module methods don't support ``get_module_rref``.
        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            rref = remote_module.get_module_rref()
            self.assertEqual(rref, remote_module.module_rref)
            for param in rref.to_here().parameters():
                self.assertTrue(torch.equal(param, _PARAM_VAL))

    @dist_utils.dist_init
    def test_train_eval(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            remote_module.train()
            ret1 = rpc.rpc_sync(
                dst_worker_name,
                get_remote_training_arg,
                args=(remote_module.get_module_rref(),),
            )
            self.assertEqual(ret1, True)

            remote_module.eval()
            ret2 = rpc.rpc_sync(
                dst_worker_name,
                get_remote_training_arg,
                args=(remote_module.get_module_rref(),),
            )
            self.assertEqual(ret2, False)

    @dist_utils.dist_init
    def test_unsupported_methods(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            with self.assertRaisesRegex(
                ValueError, r"Method ``register_buffer`` not supported for RemoteModule"
            ):
                remote_module.register_buffer("buffer", torch.ones(5))
            with self.assertRaisesRegex(
                ValueError,
                r"Method ``register_parameter`` not supported for RemoteModule",
            ):
                remote_module.register_parameter(
                    "param", torch.nn.Parameter(torch.ones(1))
                )
            with self.assertRaisesRegex(
                ValueError, r"Method ``add_module`` not supported for RemoteModule"
            ):
                remote_module.add_module("empty", None)

            with self.assertRaisesRegex(
                ValueError, r"Method ``apply`` not supported for RemoteModule"
            ):
                fn = torch.rand((3, 3), requires_grad=False)
                remote_module.apply(fn)

            with self.assertRaisesRegex(
                ValueError, r"Method ``cuda`` not supported for RemoteModule"
            ):
                remote_module.cuda()
            with self.assertRaisesRegex(
                ValueError, r"Method ``cpu`` not supported for RemoteModule"
            ):
                remote_module.cpu()
            with self.assertRaisesRegex(
                ValueError, r"Method ``type`` not supported for RemoteModule"
            ):
                remote_module.type(torch.FloatTensor)
            with self.assertRaisesRegex(
                ValueError, r"Method ``float`` not supported for RemoteModule"
            ):
                remote_module.float()
            with self.assertRaisesRegex(
                ValueError, r"Method ``double`` not supported for RemoteModule"
            ):
                remote_module.double()
            with self.assertRaisesRegex(
                ValueError, r"Method ``bfloat16`` not supported for RemoteModule"
            ):
                remote_module.bfloat16()
            with self.assertRaisesRegex(
                ValueError, r"Method ``to`` not supported for RemoteModule"
            ):
                remote_module.to("cpu", dtype=torch.int32)

            def hook(module, grad_input, grad_output):
                pass

            with self.assertRaisesRegex(
                ValueError,
                r"Method ``register_backward_hook`` not supported for RemoteModule",
            ):
                remote_module.register_backward_hook(hook)
            with self.assertRaisesRegex(
                ValueError,
                r"Method ``register_forward_pre_hook`` not supported for RemoteModule",
            ):
                remote_module.register_forward_pre_hook(hook)
            with self.assertRaisesRegex(
                ValueError,
                r"Method ``register_forward_hook`` not supported for RemoteModule",
            ):
                remote_module.register_forward_hook(hook)

            with self.assertRaisesRegex(
                ValueError, r"Method ``state_dict`` not supported for RemoteModule"
            ):
                remote_module.state_dict()
            with self.assertRaisesRegex(
                ValueError, r"Method ``load_state_dict`` not supported for RemoteModule"
            ):
                remote_module.load_state_dict({})

            with self.assertRaisesRegex(
                ValueError,
                r"Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead.",
            ):
                remote_module.parameters()
            with self.assertRaisesRegex(
                ValueError,
                r"Method ``named_parameters`` not supported for RemoteModule",
            ):
                remote_module.named_parameters()
            with self.assertRaisesRegex(
                ValueError, r"Method ``buffers`` not supported for RemoteModule"
            ):
                remote_module.buffers()
            with self.assertRaisesRegex(
                ValueError, r"Method ``named_buffers`` not supported for RemoteModule"
            ):
                remote_module.named_buffers()
            with self.assertRaisesRegex(
                ValueError, r"Method ``children`` not supported for RemoteModule"
            ):
                remote_module.children()
            with self.assertRaisesRegex(
                ValueError, r"Method ``named_children`` not supported for RemoteModule"
            ):
                remote_module.named_children()
            with self.assertRaisesRegex(
                ValueError, r"Method ``modules`` not supported for RemoteModule"
            ):
                remote_module.modules()
            with self.assertRaisesRegex(
                ValueError, r"Method ``named_modules`` not supported for RemoteModule"
            ):
                remote_module.named_modules()

            with self.assertRaisesRegex(
                ValueError, r"Method ``requires_grad_`` not supported for RemoteModule"
            ):
                remote_module.requires_grad_()
            with self.assertRaisesRegex(
                ValueError, r"Method ``zero_grad`` not supported for RemoteModule"
            ):
                remote_module.zero_grad()
            with self.assertRaisesRegex(
                ValueError, r"Method ``share_memory`` not supported for RemoteModule"
            ):
                remote_module.share_memory()
            with self.assertRaisesRegex(
                ValueError, r"Method ``extra_repr`` not supported for RemoteModule"
            ):
                remote_module.extra_repr()

    @dist_utils.dist_init
    def test_send_remote_module_with_a_new_attribute_not_pickled_over_the_wire(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        # If a new attribute is added to this RemoteModule after initialization
        # and the module is then sent over the wire by RPC,
        # this new field will not be pickled, because it's not specified in _REMOTE_MODULE_PICKLED_ATTRIBUTES.
        # Note that adding a new attribute outside the constructor should rarely happen.
        # If a new attribute is added to the RemoteModule constructor,
        # there is a sanity check that enforces developers to add this attribute to either
        # _REMOTE_MODULE_PICKLED_ATTRIBUTES or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING.
        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            new_attr_name = "new_attr"
            setattr(remote_module, new_attr_name, 1)

            attrs = rpc.rpc_sync(
                dst_worker_name, remote_module_attributes, (remote_module,)
            )
            self.assertNotIn(new_attr_name, attrs)

    @dist_utils.dist_init
    def test_remote_module_py_pickle_not_supported(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            with TemporaryFileName() as fname:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC",
                ):
                    torch.save(remote_module, fname)

    @dist_utils.dist_init
    def test_remote_module_py_pickle_not_supported_script(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        for remote_module in self._create_remote_module_iter(
            dst_worker_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
        ):
            with TemporaryFileName() as fname:
                with self.assertRaisesRegex(
                    torch.jit.Error, "can only be pickled when using RPC"
                ):
                    torch.save(remote_module, fname)


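# Tests that send a RemoteModule created on one destination worker to a second
# destination worker, hence the three-worker world size.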
class ThreeWorkersRemoteModuleTest(CommonRemoteModuleTest):
    @property
    def world_size(self):  # Override setting in CommonRemoteModuleTest
        return 3

    @dist_utils.dist_init
    def test_send_remote_module_over_the_wire(self):
        if self.rank != 0:
            return
        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)

        # Unpickled attributes include both the inherent attributes of RemoteModule
        # (not inherited from the superclass) and two installed methods.
        expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES)
        expected_unpickled_attrs.append("forward_async")
        expected_unpickled_attrs.append("forward")

        # Create a remote module on worker1 and then pass it to worker2 over the RPC layer.
        for remote_module in self._create_remote_module_iter(
            dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            # Test querying some simple attributes from worker2.
            attrs = rpc.rpc_sync(
                dst_worker2_name, remote_module_attributes, (remote_module,)
            )
            self.assertListEqual(list(attrs.keys()), expected_unpickled_attrs)
            self.assertEqual(attrs["on"], "worker1")
            self.assertEqual(attrs["device"], "cpu")
            self.assertFalse(attrs["is_device_map_set"])
            self.assertFalse(attrs["is_scriptable"])

            # Test that the methods installed on worker1 can be invoked by worker2 over the RPC layer.
            # NOTE: In practice a remote module should be stored directly on the worker that runs ``forward`` or ``forward_async``,
            # rather than having another worker initiate forward over the RPC layer.
            args = (torch.ones(1), 2, "3")
            ret1 = rpc.rpc_sync(
                dst_worker2_name, remote_forward, (remote_module, args)
            )
            self.assertEqual(ret1, tuple(reversed(args)))
            ret2 = rpc.rpc_sync(
                dst_worker2_name, remote_forward_async, (remote_module, args)
            )
            self.assertEqual(ret2, tuple(reversed(args)))

    @dist_utils.dist_init
    def test_send_remote_module_over_the_wire_script_not_supported(self):
        if self.rank != 0:
            return
        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)

        # Unpickled attributes include both the inherent attributes of RemoteModule
        # (not inherited from the superclass) and two installed methods.
        expected_unpickled_attrs = list(_REMOTE_MODULE_PICKLED_ATTRIBUTES)
        expected_unpickled_attrs.append("forward_async")
        expected_unpickled_attrs.append("forward")

        with self.assertRaisesRegex(
            RuntimeError, "Passing a script RemoteModule over RPC is not supported."
        ):
            # Create a remote module on worker1 and then pass it to worker2 over the RPC layer.
            for remote_module in self._create_remote_module_iter(
                dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE]
            ):
                # Test querying some simple attributes from worker2.
                attrs = rpc.rpc_sync(
                    dst_worker2_name, remote_module_attributes, (remote_module,)
                )

    @dist_utils.dist_init
    def test_create_remote_module_from_module_rref(self):
        if self.rank != 0:
            return
        dst_worker1_name = dist_utils.worker_name((self.rank + 1) % self.world_size)
        dst_worker2_name = dist_utils.worker_name((self.rank + 2) % self.world_size)

        # Create a remote module on worker1 and then pass its `module_rref` to worker2 over the RPC layer.
        for remote_module in self._create_remote_module_iter(
            dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            remote_module2 = rpc.rpc_sync(
                dst_worker2_name,
                RemoteModule.init_from_module_rref,
                (dst_worker2_name, remote_module.get_module_rref()),
            )

            args = (torch.ones(1), 2, "3")
            ret1 = rpc.rpc_sync(
                dst_worker1_name, remote_forward, (remote_module, args)
            )
            ret2 = rpc.rpc_sync(
                dst_worker2_name, remote_forward, (remote_module2, args)
            )
            self.assertEqual(ret1, ret2)


class CudaRemoteModuleTest(CommonRemoteModuleTest):
    @skip_if_lt_x_gpu(1)
    @dist_utils.dist_init
    def test_valid_device(self):
        if self.rank != 0:
            return
        dst_rank = (self.rank + 1) % self.world_size
        dst_worker_name = dist_utils.worker_name(dst_rank)

        for remote_module in self._create_remote_module_iter(
            f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            device = rpc.rpc_sync(
                dst_worker_name, remote_device, (remote_module.module_rref,)
            )
            self.assertEqual(device.type, "cuda")
            self.assertEqual(device.index, 0)

        # Test rank works as well.
        for remote_module in self._create_remote_module_iter(
            f"rank:{dst_rank}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            device = rpc.rpc_sync(
                dst_worker_name, remote_device, (remote_module.module_rref,)
            )
            self.assertEqual(device.type, "cuda")
            self.assertEqual(device.index, 0)

    @skip_if_lt_x_gpu(1)
    @dist_utils.dist_init
    def test_invalid_devices(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        with self.assertRaisesRegex(
            RuntimeError,
            r"Expected one of .+ device type at start of device string",
        ):
            [
                m.forward()
                for m in self._create_remote_module_iter(
                    f"{dst_worker_name}/foo",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            ]

        with self.assertRaisesRegex(
            RuntimeError, r"CUDA error: invalid device ordinal"
        ):
            [
                m.forward()
                for m in self._create_remote_module_iter(
                    f"{dst_worker_name}/cuda:100",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            ]

        with self.assertRaisesRegex(RuntimeError, r"Invalid device string: 'cpu2'"):
            [
                m.forward()
                for m in self._create_remote_module_iter(
                    f"{dst_worker_name}/cpu2",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            ]

        with self.assertRaisesRegex(RuntimeError, r"Device string must not be empty"):
            [
                m.forward()
                for m in self._create_remote_module_iter(
                    f"{dst_worker_name}/",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            ]

        with self.assertRaisesRegex(
            ValueError,
            r"Could not parse remote_device: worker1/cuda:0/cuda:1. The valid format is '<workername>/<device>'",
        ):
            [
                m.forward()
                for m in self._create_remote_module_iter(
                    f"{dst_worker_name}/cuda:0/cuda:1",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            ]

        with self.assertRaisesRegex(
            ValueError,
            r"Could not parse remote_device: /. The valid format is '<workername>/<device>'",
        ):
            [
                m.forward()
                for m in self._create_remote_module_iter(
                    "/",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            ]

        with self.assertRaisesRegex(
            ValueError,
            r"Could not parse remote_device: /cuda:0. The valid format is '<workername>/<device>'",
        ):
            [
                m.forward()
                for m in self._create_remote_module_iter(
                    "/cuda:0",
                    modes=[ModuleCreationMode.MODULE_CTOR],
                )
            ]

    @skip_if_lt_x_gpu(1)
    @dist_utils.dist_init
    def test_input_moved_to_cuda_device(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        # These two CPU tensors (in args and kwargs) should be implicitly moved to an appropriate CUDA device.
        t1 = torch.ones(1)
        args = (t1, 2)
        t2 = t1 * 2
        kwargs = dict(word=t2)

        # Only test Python nn.Module, because script module methods don't support taking kwargs.
        for remote_module in self._create_remote_module_iter(
            f"{dst_worker_name}/cuda:0", modes=[ModuleCreationMode.MODULE_CTOR]
        ):
            ret_fut = remote_module.forward_async(*args, **kwargs)
            ret = ret_fut.wait()
            self.assertEqual(ret, tuple(reversed(args + (t2,))))
            # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
            self.assertEqual(ret[0].device.type, "cpu")
            self.assertEqual(ret[2].device.type, "cpu")

            ret = remote_module.forward(*args, **kwargs)
            self.assertEqual(ret, tuple(reversed(args + (t2,))))
            # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
            self.assertEqual(ret[0].device.type, "cpu")
            self.assertEqual(ret[2].device.type, "cpu")

    @skip_if_lt_x_gpu(1)
    @dist_utils.dist_init
    def test_input_moved_to_cuda_device_script(self):
        if self.rank != 0:
            return
        dst_worker_name = dist_utils.worker_name((self.rank + 1) % self.world_size)

        scripted_remote_module = next(
            self._create_remote_module_iter(
                f"{dst_worker_name}/cuda:0",
                modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE],
            )
        )

        @torch.jit.script
        def run_forward(scripted_remote_module: MyModuleInterface):
            ret = scripted_remote_module.forward(torch.ones(1), 2, "3")
            return ret

        ret = run_forward(scripted_remote_module)

        self.assertEqual(ret, ("3", 2, torch.ones(1)))
        # TODO: Once the RPC backend can support directly sending GPU tensors, the expected device type should be "cuda:0".
        self.assertEqual(ret[2].device.type, "cpu")
735