1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "DriverTestHelpers.hpp" 7 #include <log/log.h> 8 9 #include <armnn/src/armnn/OptimizedNetworkImpl.hpp> 10 11 #include <fstream> 12 #include <memory> 13 #include <armnn/INetwork.hpp> 14 15 #include <armnnUtils/Filesystem.hpp> 16 17 using namespace android; 18 using namespace android::nn; 19 using namespace android::hardware; 20 using namespace armnn_driver; 21 22 namespace armnn 23 { 24 25 class Graph 26 { 27 public: 28 Graph(Graph&& graph) = default; 29 }; 30 31 class MockOptimizedNetworkImpl final : public ::armnn::OptimizedNetworkImpl 32 { 33 public: MockOptimizedNetworkImpl(const std::string & mockSerializedContent,std::unique_ptr<armnn::Graph>)34 MockOptimizedNetworkImpl(const std::string& mockSerializedContent, std::unique_ptr<armnn::Graph>) 35 : ::armnn::OptimizedNetworkImpl(nullptr) 36 , m_MockSerializedContent(mockSerializedContent) 37 {} ~MockOptimizedNetworkImpl()38 ~MockOptimizedNetworkImpl() {} 39 PrintGraph()40 ::armnn::Status PrintGraph() override { return ::armnn::Status::Failure; } SerializeToDot(std::ostream & stream) const41 ::armnn::Status SerializeToDot(std::ostream& stream) const override 42 { 43 stream << m_MockSerializedContent; 44 45 return stream.good() ? ::armnn::Status::Success : ::armnn::Status::Failure; 46 } 47 GetGuid() const48 ::arm::pipe::ProfilingGuid GetGuid() const final { return ::arm::pipe::ProfilingGuid(0); } 49 UpdateMockSerializedContent(const std::string & mockSerializedContent)50 void UpdateMockSerializedContent(const std::string& mockSerializedContent) 51 { 52 this->m_MockSerializedContent = mockSerializedContent; 53 } 54 55 private: 56 std::string m_MockSerializedContent; 57 }; 58 59 60 } // armnn namespace 61 62 63 // The following are helpers for writing unit tests for the driver. 64 namespace 65 { 66 67 struct ExportNetworkGraphFixture 68 { 69 public: 70 // Setup: set the output dump directory and an empty dummy model (as only its memory address is used). 71 // Defaulting the output dump directory to "/data" because it should exist and be writable in all deployments. ExportNetworkGraphFixture__anon5503157d0111::ExportNetworkGraphFixture72 ExportNetworkGraphFixture() 73 : ExportNetworkGraphFixture("/data") 74 {} 75 ExportNetworkGraphFixture__anon5503157d0111::ExportNetworkGraphFixture76 ExportNetworkGraphFixture(const std::string& requestInputsAndOutputsDumpDir) 77 : m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir), m_FileName(), m_FileStream() 78 { 79 // Set the name of the output .dot file. 80 // NOTE: the export now uses a time stamp to name the file so we 81 // can't predict ahead of time what the file name will be. 82 std::string timestamp = "dummy"; 83 m_FileName = m_RequestInputsAndOutputsDumpDir / (timestamp + "_networkgraph.dot"); 84 } 85 86 // Teardown: delete the dump file regardless of the outcome of the tests. ~ExportNetworkGraphFixture__anon5503157d0111::ExportNetworkGraphFixture87 ~ExportNetworkGraphFixture() 88 { 89 // Close the file stream. 90 m_FileStream.close(); 91 92 // Ignore any error (such as file not found). 93 (void) remove(m_FileName.c_str()); 94 } 95 FileExists__anon5503157d0111::ExportNetworkGraphFixture96 bool FileExists() 97 { 98 // Close any file opened in a previous session. 99 if (m_FileStream.is_open()) 100 { 101 m_FileStream.close(); 102 } 103 104 if (m_FileName.empty()) 105 { 106 return false; 107 } 108 109 // Open the file. 110 m_FileStream.open(m_FileName, std::ifstream::in); 111 112 // Check that the file is open. 113 if (!m_FileStream.is_open()) 114 { 115 return false; 116 } 117 118 // Check that the stream is readable. 119 return m_FileStream.good(); 120 } 121 GetFileContent__anon5503157d0111::ExportNetworkGraphFixture122 std::string GetFileContent() 123 { 124 // Check that the stream is readable. 125 if (!m_FileStream.good()) 126 { 127 return ""; 128 } 129 130 // Get all the contents of the file. 131 return std::string((std::istreambuf_iterator<char>(m_FileStream)), 132 (std::istreambuf_iterator<char>())); 133 } 134 135 fs::path m_RequestInputsAndOutputsDumpDir; 136 fs::path m_FileName; 137 138 private: 139 std::ifstream m_FileStream; 140 }; 141 142 143 } // namespace 144 145 DOCTEST_TEST_SUITE("UtilsTests") 146 { 147 148 DOCTEST_TEST_CASE("ExportToEmptyDirectory") 149 { 150 // Set the fixture for this test. 151 ExportNetworkGraphFixture fixture(""); 152 153 // Set a mock content for the optimized network. 154 std::string mockSerializedContent = "This is a mock serialized content."; 155 156 // Set a mock optimized network. 157 std::unique_ptr<armnn::Graph> graphPtr; 158 159 std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl( 160 new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr))); 161 ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl)); 162 163 // Export the mock optimized network. 164 fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 165 fixture.m_RequestInputsAndOutputsDumpDir); 166 167 // Check that the output file does not exist. 168 DOCTEST_CHECK(!fixture.FileExists()); 169 } 170 171 DOCTEST_TEST_CASE("ExportNetwork") 172 { 173 // Set the fixture for this test. 174 ExportNetworkGraphFixture fixture; 175 176 // Set a mock content for the optimized network. 177 std::string mockSerializedContent = "This is a mock serialized content."; 178 179 // Set a mock optimized network. 180 std::unique_ptr<armnn::Graph> graphPtr; 181 182 std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl( 183 new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr))); 184 ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl)); 185 186 187 // Export the mock optimized network. 188 fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 189 fixture.m_RequestInputsAndOutputsDumpDir); 190 191 // Check that the output file exists and that it has the correct name. 192 DOCTEST_CHECK(fixture.FileExists()); 193 194 // Check that the content of the output file matches the mock content. 195 DOCTEST_CHECK(fixture.GetFileContent() == mockSerializedContent); 196 } 197 198 DOCTEST_TEST_CASE("ExportNetworkOverwriteFile") 199 { 200 // Set the fixture for this test. 201 ExportNetworkGraphFixture fixture; 202 203 // Set a mock content for the optimized network. 204 std::string mockSerializedContent = "This is a mock serialized content."; 205 206 // Set a mock optimized network. 207 std::unique_ptr<armnn::Graph> graphPtr; 208 209 std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl( 210 new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr))); 211 ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl)); 212 213 // Export the mock optimized network. 214 fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 215 fixture.m_RequestInputsAndOutputsDumpDir); 216 217 // Check that the output file exists and that it has the correct name. 218 DOCTEST_CHECK(fixture.FileExists()); 219 220 // Check that the content of the output file matches the mock content. 221 DOCTEST_CHECK(fixture.GetFileContent() == mockSerializedContent); 222 223 // Update the mock serialized content of the network. 224 mockSerializedContent = "This is ANOTHER mock serialized content!"; 225 std::unique_ptr<armnn::Graph> graphPtr2; 226 std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl2( 227 new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr2))); 228 static_cast<armnn::MockOptimizedNetworkImpl*>(mockImpl2.get())->UpdateMockSerializedContent(mockSerializedContent); 229 ::armnn::IOptimizedNetwork mockOptimizedNetwork2(std::move(mockImpl2)); 230 231 // Export the mock optimized network. 232 fixture.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork2, 233 fixture.m_RequestInputsAndOutputsDumpDir); 234 235 // Check that the output file still exists and that it has the correct name. 236 DOCTEST_CHECK(fixture.FileExists()); 237 238 // Check that the content of the output file matches the mock content. 239 DOCTEST_CHECK(fixture.GetFileContent() == mockSerializedContent); 240 } 241 242 DOCTEST_TEST_CASE("ExportMultipleNetworks") 243 { 244 // Set the fixtures for this test. 245 ExportNetworkGraphFixture fixture1; 246 ExportNetworkGraphFixture fixture2; 247 ExportNetworkGraphFixture fixture3; 248 249 // Set a mock content for the optimized network. 250 std::string mockSerializedContent = "This is a mock serialized content."; 251 252 // Set a mock optimized network. 253 std::unique_ptr<armnn::Graph> graphPtr; 254 255 std::unique_ptr<::armnn::OptimizedNetworkImpl> mockImpl( 256 new armnn::MockOptimizedNetworkImpl(mockSerializedContent, std::move(graphPtr))); 257 ::armnn::IOptimizedNetwork mockOptimizedNetwork(std::move(mockImpl)); 258 259 // Export the mock optimized network. 260 fixture1.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 261 fixture1.m_RequestInputsAndOutputsDumpDir); 262 263 // Check that the output file exists and that it has the correct name. 264 DOCTEST_CHECK(fixture1.FileExists()); 265 266 // Check that the content of the output file matches the mock content. 267 DOCTEST_CHECK(fixture1.GetFileContent() == mockSerializedContent); 268 269 // Export the mock optimized network. 270 fixture2.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 271 fixture2.m_RequestInputsAndOutputsDumpDir); 272 273 // Check that the output file exists and that it has the correct name. 274 DOCTEST_CHECK(fixture2.FileExists()); 275 276 // Check that the content of the output file matches the mock content. 277 DOCTEST_CHECK(fixture2.GetFileContent() == mockSerializedContent); 278 279 // Export the mock optimized network. 280 fixture3.m_FileName = armnn_driver::ExportNetworkGraphToDotFile(mockOptimizedNetwork, 281 fixture3.m_RequestInputsAndOutputsDumpDir); 282 // Check that the output file exists and that it has the correct name. 283 DOCTEST_CHECK(fixture3.FileExists()); 284 285 // Check that the content of the output file matches the mock content. 286 DOCTEST_CHECK(fixture3.GetFileContent() == mockSerializedContent); 287 } 288 289 } 290