1# Owner(s): ["oncall: distributed"] 2 3import random 4import sys 5import unittest 6from collections import OrderedDict 7from dataclasses import dataclass 8from typing import List 9 10import torch 11import torch.nn as nn 12from torch import distributed as dist 13from torch.distributed.utils import _apply_to_tensors, _replace_by_prefix 14from torch.testing._internal.common_utils import ( 15 instantiate_parametrized_tests, 16 parametrize, 17 run_tests, 18 subtest, 19 TEST_WITH_DEV_DBG_ASAN, 20 TestCase, 21) 22 23 24if not dist.is_available(): 25 print("Distributed not available, skipping tests", file=sys.stderr) 26 sys.exit(0) 27 28if TEST_WITH_DEV_DBG_ASAN: 29 print( 30 "Skip dev-asan as torch + multiprocessing spawn have known issues", 31 file=sys.stderr, 32 ) 33 sys.exit(0) 34 35 36class TestUtils(TestCase): 37 @parametrize( 38 "devices", [["cpu"], ["cuda"], subtest(["cpu", "cuda"], name="cpu_cuda")] 39 ) 40 def test_apply_to_tensors(self, devices): 41 if "cuda" in devices and ( 42 not torch.cuda.is_available() or torch.cuda.device_count() < 1 43 ): 44 raise unittest.SkipTest("Skipped due to lack of GPU") 45 46 expected = 0 47 48 def get_a_tensor(): 49 """Return a random tensor on random device.""" 50 dev = random.choice(devices) 51 shape = random.choice(((1), (2, 3), (4, 5, 6), (7, 8, 9, 10))) 52 t = torch.rand(shape).to(dev) 53 nonlocal expected 54 expected += t.numel() 55 return t 56 57 @dataclass 58 class NonFrozenDataClass: 59 some_key: str 60 some_float: float 61 some_tensor: List[torch.Tensor] 62 63 @dataclass(frozen=True) 64 class FrozenDataClass: 65 some_key: str 66 some_float: float 67 some_tensor: List[torch.Tensor] 68 69 # create a mixed bag of data. 70 data = [1, "str"] 71 data.append({"key1": get_a_tensor(), "key2": {1: get_a_tensor()}, "key3": 3}) 72 data.insert(0, {"x", get_a_tensor(), get_a_tensor()}) 73 data.append(([1], get_a_tensor(), (1), [get_a_tensor()], {1, 2})) 74 data.append( 75 {"non_frozen_ds": NonFrozenDataClass("some_key", 1.0, [get_a_tensor()])} 76 ) 77 data.append({"frozen_ds": FrozenDataClass("some_key", 1.0, [get_a_tensor()])}) 78 od = OrderedDict() 79 od["k"] = "value" 80 data.append(od) 81 82 total = 0 83 84 def fn(t): 85 nonlocal total 86 total += t.numel() 87 return t 88 89 new_data = _apply_to_tensors(fn, data) 90 self.assertEqual(total, expected) 91 for i, v in enumerate(data): 92 self.assertEqual(type(new_data[i]), type(v)) 93 94 def test_replace_by_prefix(self): 95 state_dict = { 96 "layer.a": torch.tensor(1), 97 "abc.layer.def": torch.tensor(2), 98 "layer.b": torch.tensor(3), 99 } 100 original_state_dict = state_dict.copy() 101 _replace_by_prefix(state_dict, "layer.", "module.layer.") 102 assert state_dict == { 103 "module.layer.a": torch.tensor(1), 104 "abc.layer.def": torch.tensor(2), 105 "module.layer.b": torch.tensor(3), 106 } 107 _replace_by_prefix(state_dict, "module.layer.", "layer.") 108 assert state_dict == original_state_dict 109 110 def test_packed_sequence(self): 111 """Test to ensure RNN packed sequences are modified correctly.""" 112 rnn = nn.RNN(5, 5) 113 114 x = torch.rand((5, 1, 5), dtype=torch.float) 115 seq_length = torch.tensor([4], dtype=torch.int) 116 117 def fill_fn(x): 118 x.fill_(0) 119 120 x = nn.utils.rnn.pack_padded_sequence(x, seq_length) 121 x, h = rnn(x) 122 x = _apply_to_tensors(fill_fn, x) 123 x, _ = nn.utils.rnn.pad_packed_sequence(x) 124 self.assertEqual(torch.sum(x), 0) 125 126 127instantiate_parametrized_tests(TestUtils) 128 129if __name__ == "__main__": 130 run_tests() 131