"""Create and remove serialized artifacts used by C++ JIT tests.

Run as ``python <this file> setup`` to write every artifact listed in
``tests`` into the current working directory, and ``python <this file>
shutdown`` to delete them again. The comment on ``TorchSaveError`` names
test/cpp/jit/tests.h as a consumer; the other artifacts are presumably read
by C++ tests as well -- NOTE(review): confirm against the test sources.
"""
import os
import sys

import torch


class Setup:
    """Interface for a test fixture that creates and later removes resources."""

    def setup(self):
        """Create the fixture's resources. Subclasses must override."""
        raise NotImplementedError

    def shutdown(self):
        """Remove the fixture's resources. Subclasses must override."""
        raise NotImplementedError


class FileSetup(Setup):
    """A fixture backed by a single file at ``path``.

    Subclasses set ``path`` (a relative path, so it resolves against the
    current working directory) and implement ``setup`` to write that file;
    ``shutdown`` deletes it.
    """

    # File written by ``setup``; subclasses override. ``None`` means "no
    # file", which makes ``shutdown`` a no-op.
    path = None

    def shutdown(self):
        """Delete the file at ``self.path``.

        A missing file (never created, e.g. because ``setup`` returned early,
        or already deleted) is not an error, and neither is an unset ``path``.
        """
        if self.path is None:
            return
        # EAFP instead of exists()+remove(): no check-then-act race, and
        # repeated shutdowns stay harmless.
        try:
            os.remove(self.path)
        except FileNotFoundError:
            pass


class EvalModeForLoadedModule(FileSetup):
    """Save a ``torch.jit.ScriptModule`` wrapping ``Dropout(0.1)`` in train mode.

    NOTE(review): presumably the consuming C++ test checks the train/eval
    state of the module after loading it -- confirm against the test.
    """

    path = "dropout_model.pt"

    def setup(self):
        """Build the scripted dropout module, set training mode, save it."""

        class Model(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = torch.nn.Dropout(0.1)

            @torch.jit.script_method
            def forward(self, x):
                x = self.dropout(x)
                return x

        model = Model()
        # Explicitly save the module in training mode.
        model = model.train()
        model.save(self.path)


class SerializationInterop(FileSetup):
    """Save a tuple of tensors with ``torch.save`` in the zipfile format."""

    path = "ivalue.pt"

    def setup(self):
        """Write ``(ones(2, 2), 2 * ones(3, 5))`` to ``path``."""
        ones = torch.ones(2, 2)
        twos = torch.ones(3, 5) * 2

        value = (ones, twos)

        torch.save(value, self.path, _use_new_zipfile_serialization=True)


# See testTorchSaveError in test/cpp/jit/tests.h for usage
class TorchSaveError(FileSetup):
    """Save a tuple of tensors with ``torch.save`` in the legacy format.

    Same payload as ``SerializationInterop`` but written with
    ``_use_new_zipfile_serialization=False``.
    """

    path = "eager_value.pt"

    def setup(self):
        """Write ``(ones(2, 2), 2 * ones(3, 5))`` to ``path`` (legacy format)."""
        ones = torch.ones(2, 2)
        twos = torch.ones(3, 5) * 2

        value = (ones, twos)

        torch.save(value, self.path, _use_new_zipfile_serialization=False)


class TorchSaveJitStream_CUDA(FileSetup):
    """Save a scripted module whose ``forward`` works on a side CUDA stream.

    When CUDA is unavailable ``setup`` writes nothing; ``shutdown`` then has
    no file to remove, which ``FileSetup.shutdown`` tolerates.
    """

    path = "saved_stream_model.pt"

    def setup(self):
        """Script the stream-using module and save it, if CUDA is available."""
        if not torch.cuda.is_available():
            return

        class Model(torch.nn.Module):
            def forward(self):
                s = torch.cuda.Stream()
                a = torch.rand(3, 4, device="cuda")
                b = torch.rand(3, 4, device="cuda")

                with torch.cuda.stream(s):
                    is_stream_s = (
                        torch.cuda.current_stream(s.device_index()).id() == s.id()
                    )
                    c = torch.cat((a, b), 0).to("cuda")
                s.synchronize()
                return is_stream_s, a, b, c

        model = Model()

        # Script the model and save
        script_model = torch.jit.script(model)
        torch.jit.save(script_model, self.path)


# Every fixture handled by the module-level ``setup``/``shutdown`` commands.
tests = [
    EvalModeForLoadedModule(),
    SerializationInterop(),
    TorchSaveError(),
    TorchSaveJitStream_CUDA(),
]


def setup():
    """Run ``setup`` on every fixture in ``tests``, in list order."""
    for test in tests:
        test.setup()


def shutdown():
    """Run ``shutdown`` on every fixture in ``tests``, in list order."""
    for test in tests:
        test.shutdown()


if __name__ == "__main__":
    commands = {"setup": setup, "shutdown": shutdown}
    # Reject a missing or unknown command with a usage message and a
    # non-zero exit status instead of crashing or silently doing nothing.
    if len(sys.argv) < 2 or sys.argv[1] not in commands:
        sys.exit("usage: " + sys.argv[0] + " {setup|shutdown}")
    commands[sys.argv[1]]()