xref: /aosp_15_r20/external/pytorch/test/distributed/nn/jit/test_instantiator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Owner(s): ["oncall: distributed"]
3
4import sys
5from pathlib import Path
6from typing import Tuple
7
8import torch
9import torch.distributed as dist
10from torch import nn, Tensor
11
12
if not dist.is_available():
    # Without torch.distributed there is nothing to test here: report the skip
    # on stderr and terminate the module with exit status 0 before the
    # distributed-only imports below are attempted.
    sys.stderr.write("Distributed not available, skipping tests\n")
    raise SystemExit(0)
16
17from torch.distributed.nn.jit import instantiator
18from torch.testing._internal.common_utils import run_tests, TestCase
19
20
# TorchScript interface whose ``forward`` signature drives template
# instantiation. ``test_get_arg_return_types_from_interface`` asserts the exact
# argument names, argument types and return type derived from this declaration,
# so editing the signature requires updating those expected strings.
@torch.jit.interface
class MyModuleInterface:
    def forward(
        self, tensor: Tensor, number: int, word: str = "default"
    ) -> Tuple[Tensor, int, str]:
        # Only the signature is relevant to these tests; the body is a placeholder.
        pass
27
28
class MyModule(nn.Module):
    """Empty ``nn.Module`` subclass: it declares no parameters, submodules or ``forward``."""
31
32
def create_module():
    """Build and return a new ``MyModule`` instance."""
    module = MyModule()
    return module
35
36
class TestInstantiator(TestCase):
    """Tests for the remote-module template instantiator.

    The template tests write generated ``.py`` files into
    ``instantiator.INSTANTIATED_TEMPLATE_DIR_PATH``. Each one first deletes any
    previously generated files, so it can assert that its own instantiation
    call produced exactly one new file.
    """

    @staticmethod
    def _generated_file_paths():
        """Return the generated template files currently on disk, as a list.

        Only files matching ``{instantiator._FILE_PREFIX}*.py`` inside the
        instantiator's output directory are considered.
        """
        dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH)
        # Materialize the lazy glob so callers can both count and iterate it.
        return list(dir_path.glob(f"{instantiator._FILE_PREFIX}*.py"))

    def _remove_generated_files(self):
        """Delete every previously generated template file and check none remain."""
        for file_path in self._generated_file_paths():
            file_path.unlink()
        self.assertEqual(len(self._generated_file_paths()), 0)

    def _check_generated_module(self, generated_module):
        """Verify an instantiated module and the single file written for it.

        Args:
            generated_module: module returned by one of the
                ``instantiate_*_remote_module_template`` functions.
        """
        for attr_name in ("_remote_forward", "_generated_methods"):
            self.assertTrue(
                hasattr(generated_module, attr_name),
                f"generated module is missing attribute {attr_name!r}",
            )
        # Directory was emptied beforehand, so exactly one file must now exist.
        self.assertEqual(len(self._generated_file_paths()), 1)

    def test_get_arg_return_types_from_interface(self):
        # The interface's forward signature (minus ``self``) is rendered to strings.
        (
            args_str,
            arg_types_str,
            return_type_str,
        ) = instantiator.get_arg_return_types_from_interface(MyModuleInterface)
        self.assertEqual(args_str, "tensor, number, word")
        self.assertEqual(arg_types_str, "tensor: Tensor, number: int, word: str")
        self.assertEqual(return_type_str, "Tuple[Tensor, int, str]")

    def test_instantiate_scripted_remote_module_template(self):
        # Start from an empty output directory so the file count is exact.
        self._remove_generated_files()

        generated_module = instantiator.instantiate_scriptable_remote_module_template(
            MyModuleInterface
        )

        self._check_generated_module(generated_module)

    def test_instantiate_non_scripted_remote_module_template(self):
        # Start from an empty output directory so the file count is exact.
        self._remove_generated_files()

        generated_module = (
            instantiator.instantiate_non_scriptable_remote_module_template()
        )

        self._check_generated_module(generated_module)
96
if __name__ == "__main__":
    # When executed as a script, hand off to the ``run_tests`` helper imported
    # from torch.testing._internal.common_utils.
    run_tests()
99