1# Owner(s): ["module: unknown"] 2 3import collections 4import unittest 5 6import torch 7from torch.testing._internal.common_utils import run_tests, TEST_WITH_ASAN, TestCase 8 9 10try: 11 import psutil 12 13 HAS_PSUTIL = True 14except ModuleNotFoundError: 15 HAS_PSUTIL = False 16 psutil = None 17 18 19device = torch.device("cpu") 20 21 22class Network(torch.nn.Module): 23 maxp1 = torch.nn.MaxPool2d(1, 1) 24 25 def forward(self, x): 26 return self.maxp1(x) 27 28 29@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run") 30@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN") 31class TestOpenMP_ParallelFor(TestCase): 32 batch = 20 33 channels = 1 34 side_dim = 80 35 x = torch.randn([batch, channels, side_dim, side_dim], device=device) 36 model = Network() 37 38 def func(self, runs): 39 p = psutil.Process() 40 # warm up for 5 runs, then things should be stable for the last 5 41 last_rss = collections.deque(maxlen=5) 42 for n in range(10): 43 for i in range(runs): 44 self.model(self.x) 45 last_rss.append(p.memory_info().rss) 46 return last_rss 47 48 def func_rss(self, runs): 49 last_rss = list(self.func(runs)) 50 # Check that the sequence is not strictly increasing 51 is_increasing = True 52 for idx in range(len(last_rss)): 53 if idx == 0: 54 continue 55 is_increasing = is_increasing and (last_rss[idx] > last_rss[idx - 1]) 56 self.assertTrue( 57 not is_increasing, msg=f"memory usage is increasing, {str(last_rss)}" 58 ) 59 60 def test_one_thread(self): 61 """Make sure there is no memory leak with one thread: issue gh-32284""" 62 torch.set_num_threads(1) 63 self.func_rss(300) 64 65 def test_n_threads(self): 66 """Make sure there is no memory leak with many threads""" 67 ncores = min(5, psutil.cpu_count(logical=False)) 68 torch.set_num_threads(ncores) 69 self.func_rss(300) 70 71 72if __name__ == "__main__": 73 run_tests() 74