1# mypy: allow-untyped-defs 2 3import os 4from abc import ABC, abstractmethod 5 6import torch.testing._internal.dist_utils 7 8 9class RpcAgentTestFixture(ABC): 10 @property 11 def world_size(self) -> int: 12 return 4 13 14 @property 15 def init_method(self): 16 use_tcp_init = os.environ.get("RPC_INIT_WITH_TCP", None) 17 if use_tcp_init == "1": 18 master_addr = os.environ["MASTER_ADDR"] 19 master_port = os.environ["MASTER_PORT"] 20 return f"tcp://{master_addr}:{master_port}" 21 else: 22 return self.file_init_method 23 24 @property 25 def file_init_method(self): 26 return torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE.format( 27 file_name=self.file_name 28 ) 29 30 @property 31 @abstractmethod 32 def rpc_backend(self): 33 pass 34 35 @property 36 @abstractmethod 37 def rpc_backend_options(self): 38 pass 39 40 def setup_fault_injection(self, faulty_messages, messages_to_delay): # noqa: B027 41 """Method used by dist_init to prepare the faulty agent. 42 43 Does nothing for other agents. 44 """ 45 46 # Shutdown sequence is not well defined, so we may see any of the following 47 # errors when running tests that simulate errors via a shutdown on the 48 # remote end. 49 @abstractmethod 50 def get_shutdown_error_regex(self): 51 """ 52 Return various error message we may see from RPC agents while running 53 tests that check for failures. This function is used to match against 54 possible errors to ensure failures were raised properly. 55 """ 56 57 @abstractmethod 58 def get_timeout_error_regex(self): 59 """ 60 Returns a partial string indicating the error we should receive when an 61 RPC has timed out. Useful for use with assertRaisesRegex() to ensure we 62 have the right errors during timeout. 63 """ 64