# Owner(s): ["oncall: mobile"] import io import tempfile import unittest import torch import torch.utils.show_pickle from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase class TestShowPickle(TestCase): @unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows") def test_scripted_model(self): class MyCoolModule(torch.nn.Module): def __init__(self, weight): super().__init__() self.weight = weight def forward(self, x): return x * self.weight m = torch.jit.script(MyCoolModule(torch.tensor([2.0]))) with tempfile.NamedTemporaryFile() as tmp: torch.jit.save(m, tmp) tmp.flush() buf = io.StringIO() torch.utils.show_pickle.main( ["", tmp.name + "@*/data.pkl"], output_stream=buf ) output = buf.getvalue() self.assertRegex(output, "MyCoolModule") self.assertRegex(output, "weight") if __name__ == "__main__": run_tests()