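# test/cpp/jit/tests_setup.py
# Generates the serialized model/tensor files used by the C++ JIT tests and
# removes them again on shutdown.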
import os
import sys

import torch


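# Abstract interface for a test fixture: create state in setup(), clean it up
# in shutdown().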
class Setup:
    def setup(self):
        raise NotImplementedError

    def shutdown(self):
        raise NotImplementedError


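# Fixture that produces a single file at `path` and deletes it on shutdown.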
class FileSetup:
    path = None

    def shutdown(self):
        if os.path.exists(self.path):
            os.remove(self.path)


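# Saves a scripted Dropout model that is still in training mode, presumably so
# the C++ side can check eval()/train() behavior after loading.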
class EvalModeForLoadedModule(FileSetup):
    path = "dropout_model.pt"

    def setup(self):
        class Model(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = torch.nn.Dropout(0.1)

            @torch.jit.script_method
            def forward(self, x):
                x = self.dropout(x)
                return x

        model = Model()
        model = model.train()
        model.save(self.path)


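# Saves a tuple of eager tensors in the zipfile-based torch.save format for the
# C++ serialization-interop test.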
class SerializationInterop(FileSetup):
    path = "ivalue.pt"

    def setup(self):
        ones = torch.ones(2, 2)
        twos = torch.ones(3, 5) * 2

        value = (ones, twos)

        torch.save(value, self.path, _use_new_zipfile_serialization=True)


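# Same tensors as above, but saved with the legacy (non-zipfile) format.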
# See testTorchSaveError in test/cpp/jit/tests.h for usage
class TorchSaveError(FileSetup):
    path = "eager_value.pt"

    def setup(self):
        ones = torch.ones(2, 2)
        twos = torch.ones(3, 5) * 2

        value = (ones, twos)

        torch.save(value, self.path, _use_new_zipfile_serialization=False)


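# Scripts and saves a model that creates its own CUDA stream, runs a cat inside
# the `with torch.cuda.stream(s)` block, and reports whether that stream was
# current at the time. No file is written when CUDA is unavailable.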
class TorchSaveJitStream_CUDA(FileSetup):
    path = "saved_stream_model.pt"

    def setup(self):
        if not torch.cuda.is_available():
            return

        class Model(torch.nn.Module):
            def forward(self):
                s = torch.cuda.Stream()
                a = torch.rand(3, 4, device="cuda")
                b = torch.rand(3, 4, device="cuda")

                with torch.cuda.stream(s):
                    is_stream_s = (
                        torch.cuda.current_stream(s.device_index()).id() == s.id()
                    )
                    c = torch.cat((a, b), 0).to("cuda")
                s.synchronize()
                return is_stream_s, a, b, c

        model = Model()

        # Script the model and save
        script_model = torch.jit.script(model)
        torch.jit.save(script_model, self.path)


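# All fixtures to run; setup() and shutdown() below iterate over this list.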
tests = [
    EvalModeForLoadedModule(),
    SerializationInterop(),
    TorchSaveError(),
    TorchSaveJitStream_CUDA(),
]


def setup():
    for test in tests:
        test.setup()


def shutdown():
    for test in tests:
        test.shutdown()


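# Command-line entry point: `python tests_setup.py setup` creates the files,
# `python tests_setup.py shutdown` removes them.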
if __name__ == "__main__":
    command = sys.argv[1]
    if command == "setup":
        setup()
    elif command == "shutdown":
        shutdown()