1#!/usr/bin/env python3 2# Owner(s): ["oncall: distributed"] 3 4import sys 5from pathlib import Path 6from typing import Tuple 7 8import torch 9import torch.distributed as dist 10from torch import nn, Tensor 11 12 13if not dist.is_available(): 14 print("Distributed not available, skipping tests", file=sys.stderr) 15 sys.exit(0) 16 17from torch.distributed.nn.jit import instantiator 18from torch.testing._internal.common_utils import run_tests, TestCase 19 20 21@torch.jit.interface 22class MyModuleInterface: 23 def forward( 24 self, tensor: Tensor, number: int, word: str = "default" 25 ) -> Tuple[Tensor, int, str]: 26 pass 27 28 29class MyModule(nn.Module): 30 pass 31 32 33def create_module(): 34 return MyModule() 35 36 37class TestInstantiator(TestCase): 38 def test_get_arg_return_types_from_interface(self): 39 ( 40 args_str, 41 arg_types_str, 42 return_type_str, 43 ) = instantiator.get_arg_return_types_from_interface(MyModuleInterface) 44 self.assertEqual(args_str, "tensor, number, word") 45 self.assertEqual(arg_types_str, "tensor: Tensor, number: int, word: str") 46 self.assertEqual(return_type_str, "Tuple[Tensor, int, str]") 47 48 def test_instantiate_scripted_remote_module_template(self): 49 dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH) 50 51 # Cleanup. 52 file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") 53 for file_path in file_paths: 54 file_path.unlink() 55 56 # Check before run. 57 file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") 58 num_files_before = len(list(file_paths)) 59 self.assertEqual(num_files_before, 0) 60 61 generated_module = instantiator.instantiate_scriptable_remote_module_template( 62 MyModuleInterface 63 ) 64 self.assertTrue(hasattr(generated_module, "_remote_forward")) 65 self.assertTrue(hasattr(generated_module, "_generated_methods")) 66 67 # Check after run. 68 file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") 69 num_files_after = len(list(file_paths)) 70 self.assertEqual(num_files_after, 1) 71 72 def test_instantiate_non_scripted_remote_module_template(self): 73 dir_path = Path(instantiator.INSTANTIATED_TEMPLATE_DIR_PATH) 74 75 # Cleanup. 76 file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") 77 for file_path in file_paths: 78 file_path.unlink() 79 80 # Check before run. 81 file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") 82 num_files_before = len(list(file_paths)) 83 self.assertEqual(num_files_before, 0) 84 85 generated_module = ( 86 instantiator.instantiate_non_scriptable_remote_module_template() 87 ) 88 self.assertTrue(hasattr(generated_module, "_remote_forward")) 89 self.assertTrue(hasattr(generated_module, "_generated_methods")) 90 91 # Check after run. 92 file_paths = dir_path.glob(f"{instantiator._FILE_PREFIX}*.py") 93 num_files_after = len(list(file_paths)) 94 self.assertEqual(num_files_after, 1) 95 96 97if __name__ == "__main__": 98 run_tests() 99