xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_file_format.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/mobile/file_format.h>
2 
3 #include <gtest/gtest.h>
4 
5 #include <sstream>
6 
7 // Tests go in torch::jit
8 namespace torch {
9 namespace jit {
10 
TEST(FileFormatTest, IdentifiesFlatbufferStream) {
  // Bytes 4..7 carry the Flatbuffer magic "PTMF"; the four bytes before it
  // and the four bytes after it are arbitrary filler.
  const std::string payload = std::string("abcd") + "PTMF" + "efgh";
  std::stringstream stream;
  stream << payload;

  // Detection should report the Flatbuffer format.
  const auto detected = getFileFormat(stream);
  EXPECT_EQ(detected, FileFormat::FlatbufferFileFormat);
}
21 
TEST(FileFormatTest, IdentifiesZipStream) {
  // The ZIP magic bytes sit at the very start of the stream.
  std::stringstream stream;
  stream << "PK\x03\x04"; // ZIP magic bytes.
  stream << "abcdefgh"; // Arbitrary trailing bytes.

  // Detection should report the ZIP format.
  const auto detected = getFileFormat(stream);
  EXPECT_EQ(detected, FileFormat::ZipFileFormat);
}
32 
TEST(FileFormatTest, FlatbufferTakesPrecedence) {
  // The ZIP magic is checked at offset 0 and the Flatbuffer magic at offset 4,
  // so a single buffer can satisfy both checks at once. This test pins down
  // that Flatbuffer wins in that case (rationale lives in file_format.h).
  const std::string zip_magic = "PK\x03\x04";
  const std::string flatbuffer_magic = "PTMF";
  std::stringstream stream;
  stream << zip_magic << flatbuffer_magic << "abcd"; // "abcd" is filler.

  EXPECT_EQ(getFileFormat(stream), FileFormat::FlatbufferFileFormat);
}
45 
TEST(FileFormatTest, HandlesUnknownStream) {
  // Twelve bytes that contain neither the ZIP nor the Flatbuffer magic.
  std::stringstream stream;
  stream << std::string("abcd") + "efgh" + "ijkl";

  // Nothing matches, so the format must be reported as unknown.
  const auto detected = getFileFormat(stream);
  EXPECT_EQ(detected, FileFormat::UnknownFileFormat);
}
56 
TEST(FileFormatTest, ShortStreamIsUnknown) {
  // Only four bytes: fewer than the header size getFileFormat() inspects
  // (kFileFormatHeaderSize, 8 bytes).
  std::stringstream stream;
  stream << "ABCD";

  // A truncated header must be reported as unknown.
  EXPECT_EQ(getFileFormat(stream), FileFormat::UnknownFileFormat);
}
65 
TEST(FileFormatTest, EmptyStreamIsUnknown) {
  // A default-constructed stream holds no bytes at all, so it must be
  // reported as unknown.
  std::stringstream empty_stream;
  const auto detected = getFileFormat(empty_stream);
  EXPECT_EQ(detected, FileFormat::UnknownFileFormat);
}
73 
TEST(FileFormatTest, BadStreamIsUnknown) {
  // Start from a stream that holds a well-formed Flatbuffer header.
  std::stringstream stream;
  stream << "abcd" << "PTMF" << "efgh"; // "PTMF" is the Flatbuffer magic.

  // Baseline: while the stream is healthy it is detected as Flatbuffer.
  EXPECT_EQ(getFileFormat(stream), FileFormat::FlatbufferFileFormat);

  // Force the stream into an error state and confirm it is no longer good().
  stream.setstate(std::ios_base::badbit);
  EXPECT_FALSE(stream.good());

  // The bytes are unchanged, but a bad stream must be reported as unknown.
  EXPECT_EQ(getFileFormat(stream), FileFormat::UnknownFileFormat);
}
92 
TEST(FileFormatTest, StreamOffsetIsObservedAndRestored) {
  // Create data with a Flatbuffer header at a non-zero offset into the stream.
  std::stringstream data;
  // Add initial padding.
  data << "PADDING";
  // Keep the offset as std::streamoff, the type used by the stream
  // positioning API. This avoids both the cast at seekg() and a mixed
  // unsigned (size_t) / signed (std::streampos -> std::streamoff) comparison
  // when checking tellg() below.
  const auto offset = static_cast<std::streamoff>(data.str().size());
  // Add a valid Flatbuffer header.
  data << "abcd"
       << "PTMF" // Flatbuffer magic string.
       << "efgh";
  // Seek just after the padding.
  data.seekg(offset, std::ios_base::beg);
  // Demonstrate that the stream points to the beginning of the Flatbuffer data,
  // not to the padding.
  EXPECT_EQ(data.peek(), 'a');

  // The data should be identified as Flatbuffer.
  EXPECT_EQ(getFileFormat(data), FileFormat::FlatbufferFileFormat);

  // The stream position should be where it was before identification.
  EXPECT_EQ(offset, static_cast<std::streamoff>(data.tellg()));
}
115 
TEST(FileFormatTest, HandlesMissingFile) {
  // A path that embeds a UUID, so it is practically guaranteed not to exist.
  static constexpr char kMissingPath[] =
      "NON_EXISTENT_FILE_4965c363-44a7-443c-983a-8895eead0277";

  // A file that does not exist must be reported as unknown.
  const auto detected = getFileFormat(kMissingPath);
  EXPECT_EQ(detected, FileFormat::UnknownFileFormat);
}
122 
123 } // namespace jit
124 } // namespace torch
125