xref: /aosp_15_r20/external/armnn/src/profiling/test/ProfilingTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ProfilingTests.hpp"
7 #include "ProfilingTestUtils.hpp"
8 #include <Runtime.hpp>
9 #include <ArmNNProfilingServiceInitialiser.hpp>
10 
11 #include <client/src/CommandHandler.hpp>
12 #include <client/src/ConnectionAcknowledgedCommandHandler.hpp>
13 #include <client/src/PeriodicCounterCapture.hpp>
14 #include <client/src/PeriodicCounterSelectionCommandHandler.hpp>
15 #include <client/src/ProfilingStateMachine.hpp>
16 #include <client/src/ProfilingUtils.hpp>
17 #include <client/src/RegisterBackendCounters.hpp>
18 #include <client/src/RequestCounterDirectoryCommandHandler.hpp>
19 #include <client/src/SocketProfilingConnection.hpp>
20 #include <client/src/SendCounterPacket.hpp>
21 #include <client/src/SendThread.hpp>
22 #include <client/src/SendTimelinePacket.hpp>
23 #include <client/src/backends/BackendProfiling.hpp>
24 
25 #include <armnn/Utils.hpp>
26 
27 #include <armnn/profiling/ArmNNProfiling.hpp>
28 
29 #include <client/include/CounterIdMap.hpp>
30 #include <client/include/Holder.hpp>
31 #include <client/include/ICounterValues.hpp>
32 #include <client/include/ProfilingOptions.hpp>
33 
34 #include <common/include/CommandHandlerKey.hpp>
35 #include <common/include/CommandHandlerRegistry.hpp>
36 #include <common/include/CounterDirectory.hpp>
37 #include <common/include/EncodeVersion.hpp>
38 #include <common/include/IgnoreUnused.hpp>
39 #include <common/include/NumericCast.hpp>
40 #include <common/include/Packet.hpp>
41 #include <common/include/PacketVersionResolver.hpp>
42 #include <common/include/SocketConnectionException.hpp>
43 #include <common/include/SwTrace.hpp>
44 
45 #include <doctest/doctest.h>
46 
47 #include <algorithm>
48 #include <cstdint>
49 #include <cstring>
50 #include <iostream>
51 #include <limits>
52 #include <map>
53 #include <random>
54 
55 
56 using namespace arm::pipe;
57 using PacketType = MockProfilingConnection::PacketType;
58 
59 TEST_SUITE("ExternalProfiling")
60 {
61 TEST_CASE("CheckCommandHandlerKeyComparisons")
62 {
63     arm::pipe::CommandHandlerKey testKey1_0(1, 1, 1);
64     arm::pipe::CommandHandlerKey testKey1_1(1, 1, 1);
65     arm::pipe::CommandHandlerKey testKey1_2(1, 2, 1);
66 
67     arm::pipe::CommandHandlerKey testKey0(0, 1, 1);
68     arm::pipe::CommandHandlerKey testKey1(0, 1, 1);
69     arm::pipe::CommandHandlerKey testKey2(0, 1, 1);
70     arm::pipe::CommandHandlerKey testKey3(0, 0, 0);
71     arm::pipe::CommandHandlerKey testKey4(0, 2, 2);
72     arm::pipe::CommandHandlerKey testKey5(0, 0, 2);
73 
74     CHECK(testKey1_0 > testKey0);
75     CHECK(testKey1_0 == testKey1_1);
76     CHECK(testKey1_0 < testKey1_2);
77 
78     CHECK(testKey1 < testKey4);
79     CHECK(testKey1 > testKey3);
80     CHECK(testKey1 <= testKey4);
81     CHECK(testKey1 >= testKey3);
82     CHECK(testKey1 <= testKey2);
83     CHECK(testKey1 >= testKey2);
84     CHECK(testKey1 == testKey2);
85     CHECK(testKey1 == testKey1);
86 
87     CHECK(!(testKey1 == testKey5));
88     CHECK(!(testKey1 != testKey1));
89     CHECK(testKey1 != testKey5);
90 
91     CHECK((testKey1 == testKey2 && testKey2 == testKey1));
92     CHECK((testKey0 == testKey1 && testKey1 == testKey2 && testKey0 == testKey2));
93 
94     CHECK(testKey1.GetPacketId() == 1);
95     CHECK(testKey1.GetVersion() == 1);
96 
97     std::vector<arm::pipe::CommandHandlerKey> vect = {
98         arm::pipe::CommandHandlerKey(0, 0, 1), arm::pipe::CommandHandlerKey(0, 2, 0),
99         arm::pipe::CommandHandlerKey(0, 1, 0), arm::pipe::CommandHandlerKey(0, 2, 1),
100         arm::pipe::CommandHandlerKey(0, 1, 1), arm::pipe::CommandHandlerKey(0, 0, 1),
101         arm::pipe::CommandHandlerKey(0, 2, 0), arm::pipe::CommandHandlerKey(0, 0, 0) };
102 
103     std::sort(vect.begin(), vect.end());
104 
105     std::vector<arm::pipe::CommandHandlerKey> expectedVect = {
106         arm::pipe::CommandHandlerKey(0, 0, 0), arm::pipe::CommandHandlerKey(0, 0, 1),
107         arm::pipe::CommandHandlerKey(0, 0, 1), arm::pipe::CommandHandlerKey(0, 1, 0),
108         arm::pipe::CommandHandlerKey(0, 1, 1), arm::pipe::CommandHandlerKey(0, 2, 0),
109         arm::pipe::CommandHandlerKey(0, 2, 0), arm::pipe::CommandHandlerKey(0, 2, 1) };
110 
111     CHECK(vect == expectedVect);
112 }
113 
114 TEST_CASE("CheckPacketKeyComparisons")
115 {
116     arm::pipe::PacketKey key0(0, 0);
117     arm::pipe::PacketKey key1(0, 0);
118     arm::pipe::PacketKey key2(0, 1);
119     arm::pipe::PacketKey key3(0, 2);
120     arm::pipe::PacketKey key4(1, 0);
121     arm::pipe::PacketKey key5(1, 0);
122     arm::pipe::PacketKey key6(1, 1);
123 
124     CHECK(!(key0 < key1));
125     CHECK(!(key0 > key1));
126     CHECK(key0 <= key1);
127     CHECK(key0 >= key1);
128     CHECK(key0 == key1);
129     CHECK(key0 < key2);
130     CHECK(key2 < key3);
131     CHECK(key3 > key0);
132     CHECK(key4 == key5);
133     CHECK(key4 > key0);
134     CHECK(key5 < key6);
135     CHECK(key5 <= key6);
136     CHECK(key5 != key6);
137 }
138 
139 TEST_CASE("CheckCommandHandler")
140 {
141     LogLevelSwapper logLevelSwapper(arm::pipe::LogSeverity::Fatal);
142 
143     arm::pipe::PacketVersionResolver packetVersionResolver;
144     ProfilingStateMachine profilingStateMachine;
145 
146     TestProfilingConnectionBase testProfilingConnectionBase;
147     TestProfilingConnectionTimeoutError testProfilingConnectionTimeOutError;
148     TestProfilingConnectionArmnnError testProfilingConnectionArmnnError;
149     CounterDirectory counterDirectory;
150     MockBufferManager mockBuffer(1024);
151     SendCounterPacket sendCounterPacket(mockBuffer,
152                                         arm::pipe::ARMNN_SOFTWARE_INFO,
153                                         arm::pipe::ARMNN_SOFTWARE_VERSION,
154                                         arm::pipe::ARMNN_HARDWARE_VERSION);
155     SendThread sendThread(profilingStateMachine, mockBuffer, sendCounterPacket);
156     SendTimelinePacket sendTimelinePacket(mockBuffer);
157     MockProfilingServiceStatus mockProfilingServiceStatus;
158 
159     ConnectionAcknowledgedCommandHandler connectionAcknowledgedCommandHandler(0, 1, 4194304, counterDirectory,
160                                                                               sendCounterPacket, sendTimelinePacket,
161                                                                               profilingStateMachine,
162                                                                               mockProfilingServiceStatus);
163     arm::pipe::CommandHandlerRegistry commandHandlerRegistry;
164 
165     commandHandlerRegistry.RegisterFunctor(&connectionAcknowledgedCommandHandler);
166 
167     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
168     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
169 
170     CommandHandler commandHandler0(1, true, commandHandlerRegistry, packetVersionResolver);
171 
172     // This should start the command handler thread return the connection ack and put the profiling
173     // service into active state.
174     commandHandler0.Start(testProfilingConnectionBase);
175     // Try to start the send thread many times, it must only start once
176     commandHandler0.Start(testProfilingConnectionBase);
177 
178     // This could take up to 20mSec but we'll check often.
179     for (int i = 0; i < 10; i++)
180     {
181         if (profilingStateMachine.GetCurrentState() == ProfilingState::Active)
182         {
183             break;
184         }
185         std::this_thread::sleep_for(std::chrono::milliseconds(2));
186     }
187 
188     CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
189 
190     // Close the thread again.
191     commandHandler0.Stop();
192 
193     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
194     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
195 
196     // In this test we'll simulate a timeout without a connection ack packet being received.
197     // Stop after timeout is set so we expect the command handler to stop almost immediately.
198     CommandHandler commandHandler1(1, true, commandHandlerRegistry, packetVersionResolver);
199 
200     commandHandler1.Start(testProfilingConnectionTimeOutError);
201     // Wait until we know a timeout exception has been sent at least once.
202     for (int i = 0; i < 10; i++)
203     {
204         if (testProfilingConnectionTimeOutError.ReadCalledCount())
205         {
206             break;
207         }
208         std::this_thread::sleep_for(std::chrono::milliseconds(2));
209     }
210 
211     // The command handler loop should have stopped after the timeout.
212     // wait for the timeout exception to be processed and the loop to break.
213     uint32_t timeout   = 50;
214     uint32_t timeSlept = 0;
215     while (commandHandler1.IsRunning())
216     {
217         if (timeSlept >= timeout)
218         {
219             FAIL("Timeout: The command handler loop did not stop after the timeout");
220         }
221         std::this_thread::sleep_for(std::chrono::milliseconds(1));
222         timeSlept ++;
223     }
224 
225     commandHandler1.Stop();
226     // The state machine should never have received the ack so will still be in WaitingForAck.
227     CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
228 
229     // Now try sending a bad connection acknowledged packet
230     TestProfilingConnectionBadAckPacket testProfilingConnectionBadAckPacket;
231     commandHandler1.Start(testProfilingConnectionBadAckPacket);
232     commandHandler1.Stop();
233     // This should also not change the state machine
234     CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::WaitingForAck);
235 
236     // Disable stop after timeout and now commandHandler1 should persist after a timeout
237     commandHandler1.SetStopAfterTimeout(false);
238     // Restart the thread.
239     commandHandler1.Start(testProfilingConnectionTimeOutError);
240 
241     // Wait for at the three timeouts and the ack to be sent.
242     for (int i = 0; i < 10; i++)
243     {
244         if (testProfilingConnectionTimeOutError.ReadCalledCount() > 3)
245         {
246             break;
247         }
248         std::this_thread::sleep_for(std::chrono::milliseconds(2));
249     }
250     commandHandler1.Stop();
251 
252     // Even after the 3 exceptions the ack packet should have transitioned the command handler to active.
253     CHECK(profilingStateMachine.GetCurrentState() == ProfilingState::Active);
254 
255     // A command handler that gets exceptions other than timeouts should keep going.
256     CommandHandler commandHandler2(1, false, commandHandlerRegistry, packetVersionResolver);
257 
258     commandHandler2.Start(testProfilingConnectionArmnnError);
259 
260     // Wait for two exceptions to be thrown.
261     for (int i = 0; i < 10; i++)
262     {
263         if (testProfilingConnectionTimeOutError.ReadCalledCount() >= 2)
264         {
265             break;
266         }
267         std::this_thread::sleep_for(std::chrono::milliseconds(2));
268     }
269 
270     CHECK(commandHandler2.IsRunning());
271     commandHandler2.Stop();
272 }
273 
274 TEST_CASE("CheckEncodeVersion")
275 {
276     arm::pipe::Version version1(12);
277 
278     CHECK(version1.GetMajor() == 0);
279     CHECK(version1.GetMinor() == 0);
280     CHECK(version1.GetPatch() == 12);
281 
282     arm::pipe::Version version2(4108);
283 
284     CHECK(version2.GetMajor() == 0);
285     CHECK(version2.GetMinor() == 1);
286     CHECK(version2.GetPatch() == 12);
287 
288     arm::pipe::Version version3(4198412);
289 
290     CHECK(version3.GetMajor() == 1);
291     CHECK(version3.GetMinor() == 1);
292     CHECK(version3.GetPatch() == 12);
293 
294     arm::pipe::Version version4(0);
295 
296     CHECK(version4.GetMajor() == 0);
297     CHECK(version4.GetMinor() == 0);
298     CHECK(version4.GetPatch() == 0);
299 
300     arm::pipe::Version version5(1, 0, 0);
301     CHECK(version5.GetEncodedValue() == 4194304);
302 }
303 
304 TEST_CASE("CheckPacketClass")
305 {
306     uint32_t length                              = 4;
307     std::unique_ptr<unsigned char[]> packetData0 = std::make_unique<unsigned char[]>(length);
308     std::unique_ptr<unsigned char[]> packetData1 = std::make_unique<unsigned char[]>(0);
309     std::unique_ptr<unsigned char[]> nullPacketData;
310 
311     arm::pipe::Packet packetTest0(472580096, length, packetData0);
312 
313     CHECK(packetTest0.GetHeader() == 472580096);
314     CHECK(packetTest0.GetPacketFamily() == 7);
315     CHECK(packetTest0.GetPacketId() == 43);
316     CHECK(packetTest0.GetLength() == length);
317     CHECK(packetTest0.GetPacketType() == 3);
318     CHECK(packetTest0.GetPacketClass() == 5);
319 
320     CHECK_THROWS_AS(arm::pipe::Packet packetTest1(472580096, 0, packetData1), arm::pipe::InvalidArgumentException);
321     CHECK_NOTHROW(arm::pipe::Packet packetTest2(472580096, 0, nullPacketData));
322 
323     arm::pipe::Packet packetTest3(472580096, 0, nullPacketData);
324     CHECK(packetTest3.GetLength() == 0);
325     CHECK(packetTest3.GetData() == nullptr);
326 
327     const unsigned char* packetTest0Data = packetTest0.GetData();
328     arm::pipe::Packet packetTest4(std::move(packetTest0));
329 
330     CHECK(packetTest0.GetData() == nullptr);
331     CHECK(packetTest4.GetData() == packetTest0Data);
332 
333     CHECK(packetTest4.GetHeader() == 472580096);
334     CHECK(packetTest4.GetPacketFamily() == 7);
335     CHECK(packetTest4.GetPacketId() == 43);
336     CHECK(packetTest4.GetLength() == length);
337     CHECK(packetTest4.GetPacketType() == 3);
338     CHECK(packetTest4.GetPacketClass() == 5);
339 }
340 
341 TEST_CASE("CheckCommandHandlerFunctor")
342 {
343     // Hard code the version as it will be the same during a single profiling session
344     uint32_t version = 1;
345 
346     TestFunctorA testFunctorA(7, 461, version);
347     TestFunctorB testFunctorB(8, 963, version);
348     TestFunctorC testFunctorC(5, 983, version);
349 
350     arm::pipe::CommandHandlerKey keyA(
351         testFunctorA.GetFamilyId(), testFunctorA.GetPacketId(), testFunctorA.GetVersion());
352     arm::pipe::CommandHandlerKey keyB(
353         testFunctorB.GetFamilyId(), testFunctorB.GetPacketId(), testFunctorB.GetVersion());
354     arm::pipe::CommandHandlerKey keyC(
355         testFunctorC.GetFamilyId(), testFunctorC.GetPacketId(), testFunctorC.GetVersion());
356 
357     // Create the unwrapped map to simulate the Command Handler Registry
358     std::map<arm::pipe::CommandHandlerKey, arm::pipe::CommandHandlerFunctor*> registry;
359 
360     registry.insert(std::make_pair(keyB, &testFunctorB));
361     registry.insert(std::make_pair(keyA, &testFunctorA));
362     registry.insert(std::make_pair(keyC, &testFunctorC));
363 
364     // Check the order of the map is correct
365     auto it = registry.begin();
366     CHECK(it->first == keyC);    // familyId == 5
367     it++;
368     CHECK(it->first == keyA);    // familyId == 7
369     it++;
370     CHECK(it->first == keyB);    // familyId == 8
371 
372     std::unique_ptr<unsigned char[]> packetDataA;
373     std::unique_ptr<unsigned char[]> packetDataB;
374     std::unique_ptr<unsigned char[]> packetDataC;
375 
376     arm::pipe::Packet packetA(500000000, 0, packetDataA);
377     arm::pipe::Packet packetB(600000000, 0, packetDataB);
378     arm::pipe::Packet packetC(400000000, 0, packetDataC);
379 
380     // Check the correct operator of derived class is called
381     registry.at(arm::pipe::CommandHandlerKey(
382         packetA.GetPacketFamily(), packetA.GetPacketId(), version))->operator()(packetA);
383     CHECK(testFunctorA.GetCount() == 1);
384     CHECK(testFunctorB.GetCount() == 0);
385     CHECK(testFunctorC.GetCount() == 0);
386 
387     registry.at(arm::pipe::CommandHandlerKey(
388         packetB.GetPacketFamily(), packetB.GetPacketId(), version))->operator()(packetB);
389     CHECK(testFunctorA.GetCount() == 1);
390     CHECK(testFunctorB.GetCount() == 1);
391     CHECK(testFunctorC.GetCount() == 0);
392 
393     registry.at(arm::pipe::CommandHandlerKey(
394         packetC.GetPacketFamily(), packetC.GetPacketId(), version))->operator()(packetC);
395     CHECK(testFunctorA.GetCount() == 1);
396     CHECK(testFunctorB.GetCount() == 1);
397     CHECK(testFunctorC.GetCount() == 1);
398 }
399 
400 TEST_CASE("CheckCommandHandlerRegistry")
401 {
402     // Hard code the version as it will be the same during a single profiling session
403     uint32_t version = 1;
404 
405     TestFunctorA testFunctorA(7, 461, version);
406     TestFunctorB testFunctorB(8, 963, version);
407     TestFunctorC testFunctorC(5, 983, version);
408 
409     // Create the Command Handler Registry
410     arm::pipe::CommandHandlerRegistry registry;
411 
412     // Register multiple different derived classes
413     registry.RegisterFunctor(&testFunctorA);
414     registry.RegisterFunctor(&testFunctorB);
415     registry.RegisterFunctor(&testFunctorC);
416 
417     std::unique_ptr<unsigned char[]> packetDataA;
418     std::unique_ptr<unsigned char[]> packetDataB;
419     std::unique_ptr<unsigned char[]> packetDataC;
420 
421     arm::pipe::Packet packetA(500000000, 0, packetDataA);
422     arm::pipe::Packet packetB(600000000, 0, packetDataB);
423     arm::pipe::Packet packetC(400000000, 0, packetDataC);
424 
425     // Check the correct operator of derived class is called
426     registry.GetFunctor(packetA.GetPacketFamily(), packetA.GetPacketId(), version)->operator()(packetA);
427     CHECK(testFunctorA.GetCount() == 1);
428     CHECK(testFunctorB.GetCount() == 0);
429     CHECK(testFunctorC.GetCount() == 0);
430 
431     registry.GetFunctor(packetB.GetPacketFamily(), packetB.GetPacketId(), version)->operator()(packetB);
432     CHECK(testFunctorA.GetCount() == 1);
433     CHECK(testFunctorB.GetCount() == 1);
434     CHECK(testFunctorC.GetCount() == 0);
435 
436     registry.GetFunctor(packetC.GetPacketFamily(), packetC.GetPacketId(), version)->operator()(packetC);
437     CHECK(testFunctorA.GetCount() == 1);
438     CHECK(testFunctorB.GetCount() == 1);
439     CHECK(testFunctorC.GetCount() == 1);
440 
441     // Re-register an existing key with a new function
442     registry.RegisterFunctor(&testFunctorC, testFunctorA.GetFamilyId(), testFunctorA.GetPacketId(), version);
443     registry.GetFunctor(packetA.GetPacketFamily(), packetA.GetPacketId(), version)->operator()(packetC);
444     CHECK(testFunctorA.GetCount() == 1);
445     CHECK(testFunctorB.GetCount() == 1);
446     CHECK(testFunctorC.GetCount() == 2);
447 
448     // Check that non-existent key returns nullptr for its functor
449     CHECK_THROWS_AS(registry.GetFunctor(0, 0, 0), arm::pipe::ProfilingException);
450 }
451 
452 TEST_CASE("CheckPacketVersionResolver")
453 {
454     // Set up random number generator for generating packetId values
455     std::random_device device;
456     std::mt19937 generator(device());
457     std::uniform_int_distribution<uint32_t> distribution(std::numeric_limits<uint32_t>::min(),
458                                                          std::numeric_limits<uint32_t>::max());
459 
460     // NOTE: Expected version is always 1.0.0, regardless of packetId
461     const arm::pipe::Version expectedVersion(1, 0, 0);
462 
463     arm::pipe::PacketVersionResolver packetVersionResolver;
464 
465     constexpr unsigned int numTests = 10u;
466 
467     for (unsigned int i = 0u; i < numTests; ++i)
468     {
469         const uint32_t familyId = distribution(generator);
470         const uint32_t packetId = distribution(generator);
471         arm::pipe::Version resolvedVersion = packetVersionResolver.ResolvePacketVersion(familyId, packetId);
472 
473         CHECK(resolvedVersion == expectedVersion);
474     }
475 }
476 
ProfilingCurrentStateThreadImpl(ProfilingStateMachine & states)477 void ProfilingCurrentStateThreadImpl(ProfilingStateMachine& states)
478 {
479     ProfilingState newState = ProfilingState::NotConnected;
480     states.GetCurrentState();
481     states.TransitionToState(newState);
482 }
483 
484 TEST_CASE("CheckProfilingStateMachine")
485 {
486     ProfilingStateMachine profilingState1(ProfilingState::Uninitialised);
487     profilingState1.TransitionToState(ProfilingState::Uninitialised);
488     CHECK(profilingState1.GetCurrentState() == ProfilingState::Uninitialised);
489 
490     ProfilingStateMachine profilingState2(ProfilingState::Uninitialised);
491     profilingState2.TransitionToState(ProfilingState::NotConnected);
492     CHECK(profilingState2.GetCurrentState() == ProfilingState::NotConnected);
493 
494     ProfilingStateMachine profilingState3(ProfilingState::NotConnected);
495     profilingState3.TransitionToState(ProfilingState::NotConnected);
496     CHECK(profilingState3.GetCurrentState() == ProfilingState::NotConnected);
497 
498     ProfilingStateMachine profilingState4(ProfilingState::NotConnected);
499     profilingState4.TransitionToState(ProfilingState::WaitingForAck);
500     CHECK(profilingState4.GetCurrentState() == ProfilingState::WaitingForAck);
501 
502     ProfilingStateMachine profilingState5(ProfilingState::WaitingForAck);
503     profilingState5.TransitionToState(ProfilingState::WaitingForAck);
504     CHECK(profilingState5.GetCurrentState() == ProfilingState::WaitingForAck);
505 
506     ProfilingStateMachine profilingState6(ProfilingState::WaitingForAck);
507     profilingState6.TransitionToState(ProfilingState::Active);
508     CHECK(profilingState6.GetCurrentState() == ProfilingState::Active);
509 
510     ProfilingStateMachine profilingState7(ProfilingState::Active);
511     profilingState7.TransitionToState(ProfilingState::NotConnected);
512     CHECK(profilingState7.GetCurrentState() == ProfilingState::NotConnected);
513 
514     ProfilingStateMachine profilingState8(ProfilingState::Active);
515     profilingState8.TransitionToState(ProfilingState::Active);
516     CHECK(profilingState8.GetCurrentState() == ProfilingState::Active);
517 
518     ProfilingStateMachine profilingState9(ProfilingState::Uninitialised);
519     CHECK_THROWS_AS(profilingState9.TransitionToState(ProfilingState::WaitingForAck), arm::pipe::ProfilingException);
520 
521     ProfilingStateMachine profilingState10(ProfilingState::Uninitialised);
522     CHECK_THROWS_AS(profilingState10.TransitionToState(ProfilingState::Active), arm::pipe::ProfilingException);
523 
524     ProfilingStateMachine profilingState11(ProfilingState::NotConnected);
525     CHECK_THROWS_AS(profilingState11.TransitionToState(ProfilingState::Uninitialised), arm::pipe::ProfilingException);
526 
527     ProfilingStateMachine profilingState12(ProfilingState::NotConnected);
528     CHECK_THROWS_AS(profilingState12.TransitionToState(ProfilingState::Active), arm::pipe::ProfilingException);
529 
530     ProfilingStateMachine profilingState13(ProfilingState::WaitingForAck);
531     CHECK_THROWS_AS(profilingState13.TransitionToState(ProfilingState::Uninitialised), arm::pipe::ProfilingException);
532 
533     ProfilingStateMachine profilingState14(ProfilingState::WaitingForAck);
534     profilingState14.TransitionToState(ProfilingState::NotConnected);
535     CHECK(profilingState14.GetCurrentState() == ProfilingState::NotConnected);
536 
537     ProfilingStateMachine profilingState15(ProfilingState::Active);
538     CHECK_THROWS_AS(profilingState15.TransitionToState(ProfilingState::Uninitialised), arm::pipe::ProfilingException);
539 
540     ProfilingStateMachine profilingState16(ProfilingState::Active);
541     CHECK_THROWS_AS(profilingState16.TransitionToState(ProfilingState::WaitingForAck), arm::pipe::ProfilingException);
542 
543     ProfilingStateMachine profilingState17(ProfilingState::Uninitialised);
544 
545     std::vector<std::thread> threads;
546     for (unsigned int i = 0; i < 5; ++i)
547     {
548         threads.push_back(std::thread(ProfilingCurrentStateThreadImpl, std::ref(profilingState17)));
549     }
550     std::for_each(threads.begin(), threads.end(), [](std::thread& theThread)
__anona15fbe1f0102(std::thread& theThread) 551     {
552         theThread.join();
553     });
554 
555     CHECK((profilingState17.GetCurrentState() == ProfilingState::NotConnected));
556 }
557 
CaptureDataWriteThreadImpl(Holder & holder,uint32_t capturePeriod,const std::vector<uint16_t> & counterIds)558 void CaptureDataWriteThreadImpl(Holder& holder, uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
559 {
560     holder.SetCaptureData(capturePeriod, counterIds, {});
561 }
562 
CaptureDataReadThreadImpl(const Holder & holder,CaptureData & captureData)563 void CaptureDataReadThreadImpl(const Holder& holder, CaptureData& captureData)
564 {
565     captureData = holder.GetCaptureData();
566 }
567 
568 TEST_CASE("CheckCaptureDataHolder")
569 {
570     std::map<uint32_t, std::vector<uint16_t>> periodIdMap;
571     std::vector<uint16_t> counterIds;
572     uint32_t numThreads = 10;
573     for (uint32_t i = 0; i < numThreads; ++i)
574     {
575         counterIds.emplace_back(i);
576         periodIdMap.insert(std::make_pair(i, counterIds));
577     }
578 
579     // Verify the read and write threads set the holder correctly
580     // and retrieve the expected values
581     Holder holder;
582     CHECK((holder.GetCaptureData()).GetCapturePeriod() == 0);
583     CHECK(((holder.GetCaptureData()).GetCounterIds()).empty());
584 
585     // Check Holder functions
586     std::thread thread1(CaptureDataWriteThreadImpl, std::ref(holder), 2, std::ref(periodIdMap[2]));
587     thread1.join();
588     CHECK((holder.GetCaptureData()).GetCapturePeriod() == 2);
589     CHECK((holder.GetCaptureData()).GetCounterIds() == periodIdMap[2]);
590     // NOTE: now that we have some initial values in the holder we don't have to worry
591     //       in the multi-threaded section below about a read thread accessing the holder
592     //       before any write thread has gotten to it so we read period = 0, counterIds empty
593     //       instead of period = 0, counterIds = {0} as will the case when write thread 0
594     //       has executed.
595 
596     CaptureData captureData;
597     std::thread thread2(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureData));
598     thread2.join();
599     CHECK(captureData.GetCapturePeriod() == 2);
600     CHECK(captureData.GetCounterIds() == periodIdMap[2]);
601 
602     std::map<uint32_t, CaptureData> captureDataIdMap;
603     for (uint32_t i = 0; i < numThreads; ++i)
604     {
605         CaptureData perThreadCaptureData;
606         captureDataIdMap.insert(std::make_pair(i, perThreadCaptureData));
607     }
608 
609     std::vector<std::thread> threadsVect;
610     std::vector<std::thread> readThreadsVect;
611     for (uint32_t i = 0; i < numThreads; ++i)
612     {
613         threadsVect.emplace_back(
614             std::thread(CaptureDataWriteThreadImpl, std::ref(holder), i, std::ref(periodIdMap[i])));
615 
616         // Verify that the CaptureData goes into the thread in a virgin state
617         CHECK(captureDataIdMap.at(i).GetCapturePeriod() == 0);
618         CHECK(captureDataIdMap.at(i).GetCounterIds().empty());
619         readThreadsVect.emplace_back(
620             std::thread(CaptureDataReadThreadImpl, std::ref(holder), std::ref(captureDataIdMap.at(i))));
621     }
622 
623     for (uint32_t i = 0; i < numThreads; ++i)
624     {
625         threadsVect[i].join();
626         readThreadsVect[i].join();
627     }
628 
629     // Look at the CaptureData that each read thread has filled
630     // the capture period it read should match the counter ids entry
631     for (uint32_t i = 0; i < numThreads; ++i)
632     {
633         CaptureData perThreadCaptureData = captureDataIdMap.at(i);
634         CHECK(perThreadCaptureData.GetCounterIds() == periodIdMap.at(perThreadCaptureData.GetCapturePeriod()));
635     }
636 }
637 
638 TEST_CASE("CaptureDataMethods")
639 {
640     // Check CaptureData setter and getter functions
641     std::vector<uint16_t> counterIds = { 42, 29, 13 };
642     CaptureData captureData;
643     CHECK(captureData.GetCapturePeriod() == 0);
644     CHECK((captureData.GetCounterIds()).empty());
645     captureData.SetCapturePeriod(150);
646     captureData.SetCounterIds(counterIds);
647     CHECK(captureData.GetCapturePeriod() == 150);
648     CHECK(captureData.GetCounterIds() == counterIds);
649 
650     // Check assignment operator
651     CaptureData secondCaptureData;
652 
653     secondCaptureData = captureData;
654     CHECK(secondCaptureData.GetCapturePeriod() == 150);
655     CHECK(secondCaptureData.GetCounterIds() == counterIds);
656 
657     // Check copy constructor
658     CaptureData copyConstructedCaptureData(captureData);
659 
660     CHECK(copyConstructedCaptureData.GetCapturePeriod() == 150);
661     CHECK(copyConstructedCaptureData.GetCounterIds() == counterIds);
662 }
663 
664 TEST_CASE("CheckProfilingServiceDisabled")
665 {
666     ProfilingOptions options;
667     armnn::ArmNNProfilingServiceInitialiser initialiser;
668     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
669                                       initialiser,
670                                       arm::pipe::ARMNN_SOFTWARE_INFO,
671                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
672                                       arm::pipe::ARMNN_HARDWARE_VERSION);
673     profilingService.ResetExternalProfilingOptions(options, true);
674     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
675     profilingService.Update();
676     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
677 }
678 
679 TEST_CASE("CheckProfilingServiceCounterDirectory")
680 {
681     ProfilingOptions options;
682     armnn::ArmNNProfilingServiceInitialiser initialiser;
683     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
684                                       initialiser,
685                                       arm::pipe::ARMNN_SOFTWARE_INFO,
686                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
687                                       arm::pipe::ARMNN_HARDWARE_VERSION);
688     profilingService.ResetExternalProfilingOptions(options, true);
689 
690     const ICounterDirectory& counterDirectory0 = profilingService.GetCounterDirectory();
691     CHECK(counterDirectory0.GetCounterCount() == 0);
692     profilingService.Update();
693     CHECK(counterDirectory0.GetCounterCount() == 0);
694 
695     options.m_EnableProfiling = true;
696     profilingService.ResetExternalProfilingOptions(options);
697 
698     const ICounterDirectory& counterDirectory1 = profilingService.GetCounterDirectory();
699     CHECK(counterDirectory1.GetCounterCount() == 0);
700     profilingService.Update();
701     CHECK(counterDirectory1.GetCounterCount() != 0);
702     // Reset the profiling service to stop any running thread
703     options.m_EnableProfiling = false;
704     profilingService.ResetExternalProfilingOptions(options, true);
705 }
706 
707 TEST_CASE("CheckProfilingServiceCounterValues")
708 {
709     ProfilingOptions options;
710     options.m_EnableProfiling          = true;
711     armnn::ArmNNProfilingServiceInitialiser initialiser;
712     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
713                                       initialiser,
714                                       arm::pipe::ARMNN_SOFTWARE_INFO,
715                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
716                                       arm::pipe::ARMNN_HARDWARE_VERSION);
717     profilingService.ResetExternalProfilingOptions(options, true);
718 
719     profilingService.Update();
720     const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
721     const Counters& counters                  = counterDirectory.GetCounters();
722     CHECK(!counters.empty());
723 
724     std::vector<std::thread> writers;
725 
726     CHECK(!counters.empty());
727     uint16_t inferencesRun = INFERENCES_RUN;
728 
729     // Test GetAbsoluteCounterValue
730     for (int i = 0; i < 4; ++i)
731     {
732         // Increment and decrement the INFERENCES_RUN counter 250 times
733         writers.push_back(std::thread([&profilingService, inferencesRun]()
__anona15fbe1f0202() 734                                       {
735                                           for (int i = 0; i < 250; ++i)
736                                           {
737                                               profilingService.IncrementCounterValue(inferencesRun);
738                                           }
739                                       }));
740         // Add 10 to the INFERENCES_RUN counter 200 times
741         writers.push_back(std::thread([&profilingService, inferencesRun]()
__anona15fbe1f0302() 742                                       {
743                                           for (int i = 0; i < 200; ++i)
744                                           {
745                                               profilingService.AddCounterValue(inferencesRun, 10);
746                                           }
747                                       }));
748         // Subtract 5 from the INFERENCES_RUN counter 200 times
749         writers.push_back(std::thread([&profilingService, inferencesRun]()
__anona15fbe1f0402() 750                                       {
751                                           for (int i = 0; i < 200; ++i)
752                                           {
753                                               profilingService.SubtractCounterValue(inferencesRun, 5);
754                                           }
755                                       }));
756     }
757     std::for_each(writers.begin(), writers.end(), mem_fn(&std::thread::join));
758 
759     uint32_t absoluteCounterValue = 0;
760 
761     CHECK_NOTHROW(absoluteCounterValue = profilingService.GetAbsoluteCounterValue(INFERENCES_RUN));
762     CHECK(absoluteCounterValue == 5000);
763 
764     // Test SetCounterValue
765     CHECK_NOTHROW(profilingService.SetCounterValue(INFERENCES_RUN, 0));
766     CHECK_NOTHROW(absoluteCounterValue = profilingService.GetAbsoluteCounterValue(INFERENCES_RUN));
767     CHECK(absoluteCounterValue == 0);
768 
769     // Test GetDeltaCounterValue
770     writers.clear();
771     uint32_t deltaCounterValue = 0;
772     //Start a reading thread to randomly read the INFERENCES_RUN counter value
773     std::thread reader([&profilingService, inferencesRun](uint32_t& deltaCounterValue)
__anona15fbe1f0502(uint32_t& deltaCounterValue) 774                        {
775                            for (int i = 0; i < 300; ++i)
776                            {
777                                deltaCounterValue += profilingService.GetDeltaCounterValue(inferencesRun);
778                            }
779                        }, std::ref(deltaCounterValue));
780 
781     for (int i = 0; i < 4; ++i)
782     {
783         // Increment and decrement the INFERENCES_RUN counter 250 times
784         writers.push_back(std::thread([&profilingService, inferencesRun]()
__anona15fbe1f0602() 785                                       {
786                                           for (int i = 0; i < 250; ++i)
787                                           {
788                                               profilingService.IncrementCounterValue(inferencesRun);
789                                           }
790                                       }));
791         // Add 10 to the INFERENCES_RUN counter 200 times
792         writers.push_back(std::thread([&profilingService, inferencesRun]()
__anona15fbe1f0702() 793                                       {
794                                           for (int i = 0; i < 200; ++i)
795                                           {
796                                               profilingService.AddCounterValue(inferencesRun, 10);
797                                           }
798                                       }));
799         // Subtract 5 from the INFERENCES_RUN counter 200 times
800         writers.push_back(std::thread([&profilingService, inferencesRun]()
__anona15fbe1f0802() 801                                       {
802                                           for (int i = 0; i < 200; ++i)
803                                           {
804                                               profilingService.SubtractCounterValue(inferencesRun, 5);
805                                           }
806                                       }));
807     }
808 
809     std::for_each(writers.begin(), writers.end(), mem_fn(&std::thread::join));
810     reader.join();
811 
812     // Do one last read in case the reader stopped early
813     deltaCounterValue += profilingService.GetDeltaCounterValue(INFERENCES_RUN);
814     CHECK(deltaCounterValue == 5000);
815 
816     // Reset the profiling service to stop any running thread
817     options.m_EnableProfiling = false;
818     profilingService.ResetExternalProfilingOptions(options, true);
819 }
820 
821 TEST_CASE("CheckProfilingObjectUids")
822 {
823     uint16_t uid = 0;
824     CHECK_NOTHROW(uid = GetNextUid());
825     CHECK(uid >= 1);
826 
827     uint16_t nextUid = 0;
828     CHECK_NOTHROW(nextUid = GetNextUid());
829     CHECK(nextUid > uid);
830 
831     std::vector<uint16_t> counterUids;
832     CHECK_NOTHROW(counterUids = GetNextCounterUids(uid,0));
833     CHECK(counterUids.size() == 1);
834 
835     std::vector<uint16_t> nextCounterUids;
836     CHECK_NOTHROW(nextCounterUids = GetNextCounterUids(nextUid, 2));
837     CHECK(nextCounterUids.size() == 2);
838     CHECK(nextCounterUids[0] > counterUids[0]);
839 
840     std::vector<uint16_t> counterUidsMultiCore;
841     uint16_t thirdUid = nextCounterUids[0];
842     uint16_t numberOfCores = 13;
843     CHECK_NOTHROW(counterUidsMultiCore = GetNextCounterUids(thirdUid, numberOfCores));
844     CHECK(counterUidsMultiCore.size() == numberOfCores);
845     CHECK(counterUidsMultiCore.front() >= nextCounterUids[0]);
846     for (size_t i = 1; i < numberOfCores; i++)
847     {
848         CHECK(counterUidsMultiCore[i] == counterUidsMultiCore[i - 1] + 1);
849     }
850     CHECK(counterUidsMultiCore.back() == counterUidsMultiCore.front() + numberOfCores - 1);
851 }
852 
853 TEST_CASE("CheckCounterDirectoryRegisterCategory")
854 {
855     CounterDirectory counterDirectory;
856     CHECK(counterDirectory.GetCategoryCount() == 0);
857     CHECK(counterDirectory.GetDeviceCount() == 0);
858     CHECK(counterDirectory.GetCounterSetCount() == 0);
859     CHECK(counterDirectory.GetCounterCount() == 0);
860 
861     // Register a category with an invalid name
862     const Category* noCategory = nullptr;
863     CHECK_THROWS_AS(noCategory = counterDirectory.RegisterCategory(""), arm::pipe::InvalidArgumentException);
864     CHECK(counterDirectory.GetCategoryCount() == 0);
865     CHECK(!noCategory);
866 
867     // Register a category with an invalid name
868     CHECK_THROWS_AS(noCategory = counterDirectory.RegisterCategory("invalid category"),
869                       arm::pipe::InvalidArgumentException);
870     CHECK(counterDirectory.GetCategoryCount() == 0);
871     CHECK(!noCategory);
872 
873     // Register a new category
874     const std::string categoryName = "some_category";
875     const Category* category       = nullptr;
876     CHECK_NOTHROW(category = counterDirectory.RegisterCategory(categoryName));
877     CHECK(counterDirectory.GetCategoryCount() == 1);
878     CHECK(category);
879     CHECK(category->m_Name == categoryName);
880     CHECK(category->m_Counters.empty());
881 
882     // Get the registered category
883     const Category* registeredCategory = counterDirectory.GetCategory(categoryName);
884     CHECK(counterDirectory.GetCategoryCount() == 1);
885     CHECK(registeredCategory);
886     CHECK(registeredCategory == category);
887 
888     // Try to get a category not registered
889     const Category* notRegisteredCategory = counterDirectory.GetCategory("not_registered_category");
890     CHECK(counterDirectory.GetCategoryCount() == 1);
891     CHECK(!notRegisteredCategory);
892 
893     // Register a category already registered
894     const Category* anotherCategory = nullptr;
895     CHECK_THROWS_AS(anotherCategory = counterDirectory.RegisterCategory(categoryName),
896                       arm::pipe::InvalidArgumentException);
897     CHECK(counterDirectory.GetCategoryCount() == 1);
898     CHECK(!anotherCategory);
899 
900     // Register a device for testing
901     const std::string deviceName = "some_device";
902     const Device* device         = nullptr;
903     CHECK_NOTHROW(device = counterDirectory.RegisterDevice(deviceName));
904     CHECK(counterDirectory.GetDeviceCount() == 1);
905     CHECK(device);
906     CHECK(device->m_Uid >= 1);
907     CHECK(device->m_Name == deviceName);
908     CHECK(device->m_Cores == 0);
909 
910     // Register a new category not associated to any device
911     const std::string categoryWoDeviceName = "some_category_without_device";
912     const Category* categoryWoDevice       = nullptr;
913     CHECK_NOTHROW(categoryWoDevice = counterDirectory.RegisterCategory(categoryWoDeviceName));
914     CHECK(counterDirectory.GetCategoryCount() == 2);
915     CHECK(categoryWoDevice);
916     CHECK(categoryWoDevice->m_Name == categoryWoDeviceName);
917     CHECK(categoryWoDevice->m_Counters.empty());
918 
919     // Register a new category associated to an invalid device name (already exist)
920     const Category* categoryInvalidDeviceName = nullptr;
921     CHECK_THROWS_AS(categoryInvalidDeviceName =
922                           counterDirectory.RegisterCategory(categoryWoDeviceName),
923                       arm::pipe::InvalidArgumentException);
924     CHECK(counterDirectory.GetCategoryCount() == 2);
925     CHECK(!categoryInvalidDeviceName);
926 
927     // Register a new category associated to a valid device
928     const std::string categoryWValidDeviceName = "some_category_with_valid_device";
929     const Category* categoryWValidDevice       = nullptr;
930     CHECK_NOTHROW(categoryWValidDevice =
931                              counterDirectory.RegisterCategory(categoryWValidDeviceName));
932     CHECK(counterDirectory.GetCategoryCount() == 3);
933     CHECK(categoryWValidDevice);
934     CHECK(categoryWValidDevice != category);
935     CHECK(categoryWValidDevice->m_Name == categoryWValidDeviceName);
936 
937     // Register a counter set for testing
938     const std::string counterSetName = "some_counter_set";
939     const CounterSet* counterSet     = nullptr;
940     CHECK_NOTHROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName));
941     CHECK(counterDirectory.GetCounterSetCount() == 1);
942     CHECK(counterSet);
943     CHECK(counterSet->m_Uid >= 1);
944     CHECK(counterSet->m_Name == counterSetName);
945     CHECK(counterSet->m_Count == 0);
946 
947     // Register a new category not associated to any counter set
948     const std::string categoryWoCounterSetName = "some_category_without_counter_set";
949     const Category* categoryWoCounterSet       = nullptr;
950     CHECK_NOTHROW(categoryWoCounterSet =
951                              counterDirectory.RegisterCategory(categoryWoCounterSetName));
952     CHECK(counterDirectory.GetCategoryCount() == 4);
953     CHECK(categoryWoCounterSet);
954     CHECK(categoryWoCounterSet->m_Name == categoryWoCounterSetName);
955 
956     // Register a new category associated to a valid counter set
957     const std::string categoryWValidCounterSetName = "some_category_with_valid_counter_set";
958     const Category* categoryWValidCounterSet       = nullptr;
959     CHECK_NOTHROW(categoryWValidCounterSet = counterDirectory.RegisterCategory(categoryWValidCounterSetName));
960     CHECK(counterDirectory.GetCategoryCount() == 5);
961     CHECK(categoryWValidCounterSet);
962     CHECK(categoryWValidCounterSet != category);
963     CHECK(categoryWValidCounterSet->m_Name == categoryWValidCounterSetName);
964 
965     // Register a new category associated to a valid device and counter set
966     const std::string categoryWValidDeviceAndValidCounterSetName = "some_category_with_valid_device_and_counter_set";
967     const Category* categoryWValidDeviceAndValidCounterSet       = nullptr;
968     CHECK_NOTHROW(categoryWValidDeviceAndValidCounterSet = counterDirectory.RegisterCategory(
969                              categoryWValidDeviceAndValidCounterSetName));
970     CHECK(counterDirectory.GetCategoryCount() == 6);
971     CHECK(categoryWValidDeviceAndValidCounterSet);
972     CHECK(categoryWValidDeviceAndValidCounterSet != category);
973     CHECK(categoryWValidDeviceAndValidCounterSet->m_Name == categoryWValidDeviceAndValidCounterSetName);
974 }
975 
976 TEST_CASE("CheckCounterDirectoryRegisterDevice")
977 {
978     CounterDirectory counterDirectory;
979     CHECK(counterDirectory.GetCategoryCount() == 0);
980     CHECK(counterDirectory.GetDeviceCount() == 0);
981     CHECK(counterDirectory.GetCounterSetCount() == 0);
982     CHECK(counterDirectory.GetCounterCount() == 0);
983 
984     // Register a device with an invalid name
985     const Device* noDevice = nullptr;
986     CHECK_THROWS_AS(noDevice = counterDirectory.RegisterDevice(""), arm::pipe::InvalidArgumentException);
987     CHECK(counterDirectory.GetDeviceCount() == 0);
988     CHECK(!noDevice);
989 
990     // Register a device with an invalid name
991     CHECK_THROWS_AS(noDevice = counterDirectory.RegisterDevice("inv@lid nam€"), arm::pipe::InvalidArgumentException);
992     CHECK(counterDirectory.GetDeviceCount() == 0);
993     CHECK(!noDevice);
994 
995     // Register a new device with no cores or parent category
996     const std::string deviceName = "some_device";
997     const Device* device         = nullptr;
998     CHECK_NOTHROW(device = counterDirectory.RegisterDevice(deviceName));
999     CHECK(counterDirectory.GetDeviceCount() == 1);
1000     CHECK(device);
1001     CHECK(device->m_Name == deviceName);
1002     CHECK(device->m_Uid >= 1);
1003     CHECK(device->m_Cores == 0);
1004 
1005     // Try getting an unregistered device
1006     const Device* unregisteredDevice = counterDirectory.GetDevice(9999);
1007     CHECK(!unregisteredDevice);
1008 
1009     // Get the registered device
1010     const Device* registeredDevice = counterDirectory.GetDevice(device->m_Uid);
1011     CHECK(counterDirectory.GetDeviceCount() == 1);
1012     CHECK(registeredDevice);
1013     CHECK(registeredDevice == device);
1014 
1015     // Register a device with the name of a device already registered
1016     const Device* deviceSameName = nullptr;
1017     CHECK_THROWS_AS(deviceSameName = counterDirectory.RegisterDevice(deviceName),
1018                                      arm::pipe::InvalidArgumentException);
1019     CHECK(counterDirectory.GetDeviceCount() == 1);
1020     CHECK(!deviceSameName);
1021 
1022     // Register a new device with cores and no parent category
1023     const std::string deviceWCoresName = "some_device_with_cores";
1024     const Device* deviceWCores         = nullptr;
1025     CHECK_NOTHROW(deviceWCores = counterDirectory.RegisterDevice(deviceWCoresName, 2));
1026     CHECK(counterDirectory.GetDeviceCount() == 2);
1027     CHECK(deviceWCores);
1028     CHECK(deviceWCores->m_Name == deviceWCoresName);
1029     CHECK(deviceWCores->m_Uid >= 1);
1030     CHECK(deviceWCores->m_Uid > device->m_Uid);
1031     CHECK(deviceWCores->m_Cores == 2);
1032 
1033     // Get the registered device
1034     const Device* registeredDeviceWCores = counterDirectory.GetDevice(deviceWCores->m_Uid);
1035     CHECK(counterDirectory.GetDeviceCount() == 2);
1036     CHECK(registeredDeviceWCores);
1037     CHECK(registeredDeviceWCores == deviceWCores);
1038     CHECK(registeredDeviceWCores != device);
1039 
1040     // Register a new device with cores and invalid parent category
1041     const std::string deviceWCoresWInvalidParentCategoryName = "some_device_with_cores_with_invalid_parent_category";
1042     const Device* deviceWCoresWInvalidParentCategory         = nullptr;
1043     CHECK_THROWS_AS(deviceWCoresWInvalidParentCategory =
1044                           counterDirectory.RegisterDevice(deviceWCoresWInvalidParentCategoryName, 3, std::string("")),
1045                       arm::pipe::InvalidArgumentException);
1046     CHECK(counterDirectory.GetDeviceCount() == 2);
1047     CHECK(!deviceWCoresWInvalidParentCategory);
1048 
1049     // Register a new device with cores and invalid parent category
1050     const std::string deviceWCoresWInvalidParentCategoryName2 = "some_device_with_cores_with_invalid_parent_category2";
1051     const Device* deviceWCoresWInvalidParentCategory2         = nullptr;
1052     CHECK_THROWS_AS(deviceWCoresWInvalidParentCategory2 = counterDirectory.RegisterDevice(
1053                           deviceWCoresWInvalidParentCategoryName2, 3, std::string("invalid_parent_category")),
1054                       arm::pipe::InvalidArgumentException);
1055     CHECK(counterDirectory.GetDeviceCount() == 2);
1056     CHECK(!deviceWCoresWInvalidParentCategory2);
1057 
1058     // Register a category for testing
1059     const std::string categoryName = "some_category";
1060     const Category* category       = nullptr;
1061     CHECK_NOTHROW(category = counterDirectory.RegisterCategory(categoryName));
1062     CHECK(counterDirectory.GetCategoryCount() == 1);
1063     CHECK(category);
1064     CHECK(category->m_Name == categoryName);
1065     CHECK(category->m_Counters.empty());
1066 
1067     // Register a new device with cores and valid parent category
1068     const std::string deviceWCoresWValidParentCategoryName = "some_device_with_cores_with_valid_parent_category";
1069     const Device* deviceWCoresWValidParentCategory         = nullptr;
1070     CHECK_NOTHROW(deviceWCoresWValidParentCategory =
1071                              counterDirectory.RegisterDevice(deviceWCoresWValidParentCategoryName, 4, categoryName));
1072     CHECK(counterDirectory.GetDeviceCount() == 3);
1073     CHECK(deviceWCoresWValidParentCategory);
1074     CHECK(deviceWCoresWValidParentCategory->m_Name == deviceWCoresWValidParentCategoryName);
1075     CHECK(deviceWCoresWValidParentCategory->m_Uid >= 1);
1076     CHECK(deviceWCoresWValidParentCategory->m_Uid > device->m_Uid);
1077     CHECK(deviceWCoresWValidParentCategory->m_Uid > deviceWCores->m_Uid);
1078     CHECK(deviceWCoresWValidParentCategory->m_Cores == 4);
1079 }
1080 
1081 TEST_CASE("CheckCounterDirectoryRegisterCounterSet")
1082 {
1083     CounterDirectory counterDirectory;
1084     CHECK(counterDirectory.GetCategoryCount() == 0);
1085     CHECK(counterDirectory.GetDeviceCount() == 0);
1086     CHECK(counterDirectory.GetCounterSetCount() == 0);
1087     CHECK(counterDirectory.GetCounterCount() == 0);
1088 
1089     // Register a counter set with an invalid name
1090     const CounterSet* noCounterSet = nullptr;
1091     CHECK_THROWS_AS(noCounterSet = counterDirectory.RegisterCounterSet(""), arm::pipe::InvalidArgumentException);
1092     CHECK(counterDirectory.GetCounterSetCount() == 0);
1093     CHECK(!noCounterSet);
1094 
1095     // Register a counter set with an invalid name
1096     CHECK_THROWS_AS(noCounterSet = counterDirectory.RegisterCounterSet("invalid name"),
1097                       arm::pipe::InvalidArgumentException);
1098     CHECK(counterDirectory.GetCounterSetCount() == 0);
1099     CHECK(!noCounterSet);
1100 
1101     // Register a new counter set with no count or parent category
1102     const std::string counterSetName = "some_counter_set";
1103     const CounterSet* counterSet     = nullptr;
1104     CHECK_NOTHROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName));
1105     CHECK(counterDirectory.GetCounterSetCount() == 1);
1106     CHECK(counterSet);
1107     CHECK(counterSet->m_Name == counterSetName);
1108     CHECK(counterSet->m_Uid >= 1);
1109     CHECK(counterSet->m_Count == 0);
1110 
1111     // Try getting an unregistered counter set
1112     const CounterSet* unregisteredCounterSet = counterDirectory.GetCounterSet(9999);
1113     CHECK(!unregisteredCounterSet);
1114 
1115     // Get the registered counter set
1116     const CounterSet* registeredCounterSet = counterDirectory.GetCounterSet(counterSet->m_Uid);
1117     CHECK(counterDirectory.GetCounterSetCount() == 1);
1118     CHECK(registeredCounterSet);
1119     CHECK(registeredCounterSet == counterSet);
1120 
1121     // Register a counter set with the name of a counter set already registered
1122     const CounterSet* counterSetSameName = nullptr;
1123     CHECK_THROWS_AS(counterSetSameName = counterDirectory.RegisterCounterSet(counterSetName),
1124                       arm::pipe::InvalidArgumentException);
1125     CHECK(counterDirectory.GetCounterSetCount() == 1);
1126     CHECK(!counterSetSameName);
1127 
1128     // Register a new counter set with count and no parent category
1129     const std::string counterSetWCountName = "some_counter_set_with_count";
1130     const CounterSet* counterSetWCount     = nullptr;
1131     CHECK_NOTHROW(counterSetWCount = counterDirectory.RegisterCounterSet(counterSetWCountName, 37));
1132     CHECK(counterDirectory.GetCounterSetCount() == 2);
1133     CHECK(counterSetWCount);
1134     CHECK(counterSetWCount->m_Name == counterSetWCountName);
1135     CHECK(counterSetWCount->m_Uid >= 1);
1136     CHECK(counterSetWCount->m_Uid > counterSet->m_Uid);
1137     CHECK(counterSetWCount->m_Count == 37);
1138 
1139     // Get the registered counter set
1140     const CounterSet* registeredCounterSetWCount = counterDirectory.GetCounterSet(counterSetWCount->m_Uid);
1141     CHECK(counterDirectory.GetCounterSetCount() == 2);
1142     CHECK(registeredCounterSetWCount);
1143     CHECK(registeredCounterSetWCount == counterSetWCount);
1144     CHECK(registeredCounterSetWCount != counterSet);
1145 
1146     // Register a new counter set with count and invalid parent category
1147     const std::string counterSetWCountWInvalidParentCategoryName = "some_counter_set_with_count_"
1148                                                                    "with_invalid_parent_category";
1149     const CounterSet* counterSetWCountWInvalidParentCategory = nullptr;
1150     CHECK_THROWS_AS(counterSetWCountWInvalidParentCategory = counterDirectory.RegisterCounterSet(
1151                           counterSetWCountWInvalidParentCategoryName, 42, std::string("")),
1152                       arm::pipe::InvalidArgumentException);
1153     CHECK(counterDirectory.GetCounterSetCount() == 2);
1154     CHECK(!counterSetWCountWInvalidParentCategory);
1155 
1156     // Register a new counter set with count and invalid parent category
1157     const std::string counterSetWCountWInvalidParentCategoryName2 = "some_counter_set_with_count_"
1158                                                                     "with_invalid_parent_category2";
1159     const CounterSet* counterSetWCountWInvalidParentCategory2 = nullptr;
1160     CHECK_THROWS_AS(counterSetWCountWInvalidParentCategory2 = counterDirectory.RegisterCounterSet(
1161                           counterSetWCountWInvalidParentCategoryName2, 42, std::string("invalid_parent_category")),
1162                       arm::pipe::InvalidArgumentException);
1163     CHECK(counterDirectory.GetCounterSetCount() == 2);
1164     CHECK(!counterSetWCountWInvalidParentCategory2);
1165 
1166     // Register a category for testing
1167     const std::string categoryName = "some_category";
1168     const Category* category       = nullptr;
1169     CHECK_NOTHROW(category = counterDirectory.RegisterCategory(categoryName));
1170     CHECK(counterDirectory.GetCategoryCount() == 1);
1171     CHECK(category);
1172     CHECK(category->m_Name == categoryName);
1173     CHECK(category->m_Counters.empty());
1174 
1175     // Register a new counter set with count and valid parent category
1176     const std::string counterSetWCountWValidParentCategoryName = "some_counter_set_with_count_"
1177                                                                  "with_valid_parent_category";
1178     const CounterSet* counterSetWCountWValidParentCategory = nullptr;
1179     CHECK_NOTHROW(counterSetWCountWValidParentCategory = counterDirectory.RegisterCounterSet(
1180                              counterSetWCountWValidParentCategoryName, 42, categoryName));
1181     CHECK(counterDirectory.GetCounterSetCount() == 3);
1182     CHECK(counterSetWCountWValidParentCategory);
1183     CHECK(counterSetWCountWValidParentCategory->m_Name == counterSetWCountWValidParentCategoryName);
1184     CHECK(counterSetWCountWValidParentCategory->m_Uid >= 1);
1185     CHECK(counterSetWCountWValidParentCategory->m_Uid > counterSet->m_Uid);
1186     CHECK(counterSetWCountWValidParentCategory->m_Uid > counterSetWCount->m_Uid);
1187     CHECK(counterSetWCountWValidParentCategory->m_Count == 42);
1188 
1189     // Register a counter set associated to a category with invalid name
1190     const std::string counterSetSameCategoryName = "some_counter_set_with_invalid_parent_category";
1191     const std::string invalidCategoryName = "";
1192     const CounterSet* counterSetSameCategory     = nullptr;
1193     CHECK_THROWS_AS(counterSetSameCategory =
1194                           counterDirectory.RegisterCounterSet(counterSetSameCategoryName, 0, invalidCategoryName),
1195                       arm::pipe::InvalidArgumentException);
1196     CHECK(counterDirectory.GetCounterSetCount() == 3);
1197     CHECK(!counterSetSameCategory);
1198 }
1199 
1200 TEST_CASE("CheckCounterDirectoryRegisterCounter")
1201 {
1202     CounterDirectory counterDirectory;
1203     CHECK(counterDirectory.GetCategoryCount() == 0);
1204     CHECK(counterDirectory.GetDeviceCount() == 0);
1205     CHECK(counterDirectory.GetCounterSetCount() == 0);
1206     CHECK(counterDirectory.GetCounterCount() == 0);
1207 
1208     // Register a counter with an invalid parent category name
1209     const Counter* noCounter = nullptr;
1210     CHECK_THROWS_AS(noCounter =
1211                           counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1212                                                            0,
1213                                                            "",
1214                                                            0,
1215                                                            1,
1216                                                            123.45f,
1217                                                            "valid ",
1218                                                            "name"),
1219                       arm::pipe::InvalidArgumentException);
1220     CHECK(counterDirectory.GetCounterCount() == 0);
1221     CHECK(!noCounter);
1222 
1223     // Register a counter with an invalid parent category name
1224     CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1225                                                                  1,
1226                                                                  "invalid parent category",
1227                                                                  0,
1228                                                                  1,
1229                                                                  123.45f,
1230                                                                  "valid name",
1231                                                                  "valid description"),
1232                       arm::pipe::InvalidArgumentException);
1233     CHECK(counterDirectory.GetCounterCount() == 0);
1234     CHECK(!noCounter);
1235 
1236     // Register a counter with an invalid class
1237     CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1238                                                                  2,
1239                                                                  "valid_parent_category",
1240                                                                  2,
1241                                                                  1,
1242                                                                  123.45f,
1243                                                                  "valid "
1244                                                                  "name",
1245                                                                  "valid description"),
1246                       arm::pipe::InvalidArgumentException);
1247     CHECK(counterDirectory.GetCounterCount() == 0);
1248     CHECK(!noCounter);
1249 
1250     // Register a counter with an invalid interpolation
1251     CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1252                                                                  4,
1253                                                                  "valid_parent_category",
1254                                                                  0,
1255                                                                  3,
1256                                                                  123.45f,
1257                                                                  "valid "
1258                                                                  "name",
1259                                                                  "valid description"),
1260                       arm::pipe::InvalidArgumentException);
1261     CHECK(counterDirectory.GetCounterCount() == 0);
1262     CHECK(!noCounter);
1263 
1264     // Register a counter with an invalid multiplier
1265     CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1266                                                                  5,
1267                                                                  "valid_parent_category",
1268                                                                  0,
1269                                                                  1,
1270                                                                  .0f,
1271                                                                  "valid "
1272                                                                  "name",
1273                                                                  "valid description"),
1274                       arm::pipe::InvalidArgumentException);
1275     CHECK(counterDirectory.GetCounterCount() == 0);
1276     CHECK(!noCounter);
1277 
1278     // Register a counter with an invalid name
1279     CHECK_THROWS_AS(
1280         noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1281                                                      6,
1282                                                      "valid_parent_category",
1283                                                      0,
1284                                                      1,
1285                                                      123.45f,
1286                                                      "",
1287                                                      "valid description"),
1288         arm::pipe::InvalidArgumentException);
1289     CHECK(counterDirectory.GetCounterCount() == 0);
1290     CHECK(!noCounter);
1291 
1292     // Register a counter with an invalid name
1293     CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1294                                                                  7,
1295                                                                  "valid_parent_category",
1296                                                                  0,
1297                                                                  1,
1298                                                                  123.45f,
1299                                                                  "invalid nam€",
1300                                                                  "valid description"),
1301                       arm::pipe::InvalidArgumentException);
1302     CHECK(counterDirectory.GetCounterCount() == 0);
1303     CHECK(!noCounter);
1304 
1305     // Register a counter with an invalid description
1306     CHECK_THROWS_AS(noCounter =
1307                           counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1308                                                            8,
1309                                                            "valid_parent_category",
1310                                                            0,
1311                                                            1,
1312                                                            123.45f,
1313                                                            "valid name",
1314                                                            ""),
1315                       arm::pipe::InvalidArgumentException);
1316     CHECK(counterDirectory.GetCounterCount() == 0);
1317     CHECK(!noCounter);
1318 
1319     // Register a counter with an invalid description
1320     CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1321                                                                  9,
1322                                                                  "valid_parent_category",
1323                                                                  0,
1324                                                                  1,
1325                                                                  123.45f,
1326                                                                  "valid "
1327                                                                  "name",
1328                                                                  "inv@lid description"),
1329                       arm::pipe::InvalidArgumentException);
1330     CHECK(counterDirectory.GetCounterCount() == 0);
1331     CHECK(!noCounter);
1332 
1333     // Register a counter with an invalid unit2
1334     CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1335                                                                  10,
1336                                                                  "valid_parent_category",
1337                                                                  0,
1338                                                                  1,
1339                                                                  123.45f,
1340                                                                  "valid name",
1341                                                                  "valid description",
1342                                                                  std::string("Mb/s2")),
1343                       arm::pipe::InvalidArgumentException);
1344     CHECK(counterDirectory.GetCounterCount() == 0);
1345     CHECK(!noCounter);
1346 
1347     // Register a counter with a non-existing parent category name
1348     CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1349                                                                  11,
1350                                                                  "invalid_parent_category",
1351                                                                  0,
1352                                                                  1,
1353                                                                  123.45f,
1354                                                                  "valid name",
1355                                                                  "valid description"),
1356                       arm::pipe::InvalidArgumentException);
1357     CHECK(counterDirectory.GetCounterCount() == 0);
1358     CHECK(!noCounter);
1359 
1360     // Try getting an unregistered counter
1361     const Counter* unregisteredCounter = counterDirectory.GetCounter(9999);
1362     CHECK(!unregisteredCounter);
1363 
1364     // Register a category for testing
1365     const std::string categoryName = "some_category";
1366     const Category* category       = nullptr;
1367     CHECK_NOTHROW(category = counterDirectory.RegisterCategory(categoryName));
1368     CHECK(counterDirectory.GetCategoryCount() == 1);
1369     CHECK(category);
1370     CHECK(category->m_Name == categoryName);
1371     CHECK(category->m_Counters.empty());
1372 
1373     // Register a counter with a valid parent category name
1374     const Counter* counter = nullptr;
1375     CHECK_NOTHROW(
1376         counter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1377                                                    12,
1378                                                    categoryName,
1379                                                    0,
1380                                                    1,
1381                                                    123.45f,
1382                                                    "valid name",
1383                                                    "valid description"));
1384     CHECK(counterDirectory.GetCounterCount() == 1);
1385     CHECK(counter);
1386     CHECK(counter->m_MaxCounterUid == counter->m_Uid);
1387     CHECK(counter->m_Class == 0);
1388     CHECK(counter->m_Interpolation == 1);
1389     CHECK(counter->m_Multiplier == 123.45f);
1390     CHECK(counter->m_Name == "valid name");
1391     CHECK(counter->m_Description == "valid description");
1392     CHECK(counter->m_Units == "");
1393     CHECK(counter->m_DeviceUid == 0);
1394     CHECK(counter->m_CounterSetUid == 0);
1395     CHECK(category->m_Counters.size() == 1);
1396     CHECK(category->m_Counters.back() == counter->m_Uid);
1397 
1398     // Register a counter with a name of a counter already registered for the given parent category name
1399     const Counter* counterSameName = nullptr;
1400     CHECK_THROWS_AS(counterSameName =
1401                           counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1402                                                            13,
1403                                                            categoryName,
1404                                                            0,
1405                                                            0,
1406                                                            1.0f,
1407                                                            "valid name",
1408                                                            "valid description",
1409                                                            std::string("description")),
1410                       arm::pipe::InvalidArgumentException);
1411     CHECK(counterDirectory.GetCounterCount() == 1);
1412     CHECK(!counterSameName);
1413 
1414     // Register a counter with a valid parent category name and units
1415     const Counter* counterWUnits = nullptr;
1416     CHECK_NOTHROW(counterWUnits = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1417                                                                    14,
1418                                                                    categoryName,
1419                                                                    0,
1420                                                                    1,
1421                                                                    123.45f,
1422                                                                    "valid name 2",
1423                                                                    "valid description",
1424                                                                    std::string("Mnnsq2")));    // Units
1425     CHECK(counterDirectory.GetCounterCount() == 2);
1426     CHECK(counterWUnits);
1427     CHECK(counterWUnits->m_Uid > counter->m_Uid);
1428     CHECK(counterWUnits->m_MaxCounterUid == counterWUnits->m_Uid);
1429     CHECK(counterWUnits->m_Class == 0);
1430     CHECK(counterWUnits->m_Interpolation == 1);
1431     CHECK(counterWUnits->m_Multiplier == 123.45f);
1432     CHECK(counterWUnits->m_Name == "valid name 2");
1433     CHECK(counterWUnits->m_Description == "valid description");
1434     CHECK(counterWUnits->m_Units == "Mnnsq2");
1435     CHECK(counterWUnits->m_DeviceUid == 0);
1436     CHECK(counterWUnits->m_CounterSetUid == 0);
1437     CHECK(category->m_Counters.size() == 2);
1438     CHECK(category->m_Counters.back() == counterWUnits->m_Uid);
1439 
1440     // Register a counter with a valid parent category name and not associated with a device
1441     const Counter* counterWoDevice = nullptr;
1442     CHECK_NOTHROW(counterWoDevice = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1443                                                                      26,
1444                                                                      categoryName,
1445                                                                      0,
1446                                                                      1,
1447                                                                      123.45f,
1448                                                                      "valid name 3",
1449                                                                      "valid description",
1450                                                                      arm::pipe::EmptyOptional(),// Units
1451                                                                      arm::pipe::EmptyOptional(),// Number of cores
1452                                                                      0));                   // Device UID
1453     CHECK(counterDirectory.GetCounterCount() == 3);
1454     CHECK(counterWoDevice);
1455     CHECK(counterWoDevice->m_Uid > counter->m_Uid);
1456     CHECK(counterWoDevice->m_MaxCounterUid == counterWoDevice->m_Uid);
1457     CHECK(counterWoDevice->m_Class == 0);
1458     CHECK(counterWoDevice->m_Interpolation == 1);
1459     CHECK(counterWoDevice->m_Multiplier == 123.45f);
1460     CHECK(counterWoDevice->m_Name == "valid name 3");
1461     CHECK(counterWoDevice->m_Description == "valid description");
1462     CHECK(counterWoDevice->m_Units == "");
1463     CHECK(counterWoDevice->m_DeviceUid == 0);
1464     CHECK(counterWoDevice->m_CounterSetUid == 0);
1465     CHECK(category->m_Counters.size() == 3);
1466     CHECK(category->m_Counters.back() == counterWoDevice->m_Uid);
1467 
1468     // Register a counter with a valid parent category name and associated to an invalid device
1469     CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1470                                                                 15,
1471                                                                 categoryName,
1472                                                                 0,
1473                                                                 1,
1474                                                                 123.45f,
1475                                                                 "valid name 4",
1476                                                                 "valid description",
1477                                                                 arm::pipe::EmptyOptional(),    // Units
1478                                                                 arm::pipe::EmptyOptional(),    // Number of cores
1479                                                                 100),                      // Device UID
1480                       arm::pipe::InvalidArgumentException);
1481     CHECK(counterDirectory.GetCounterCount() == 3);
1482     CHECK(!noCounter);
1483 
1484     // Register a device for testing
1485     const std::string deviceName = "some_device";
1486     const Device* device         = nullptr;
1487     CHECK_NOTHROW(device = counterDirectory.RegisterDevice(deviceName));
1488     CHECK(counterDirectory.GetDeviceCount() == 1);
1489     CHECK(device);
1490     CHECK(device->m_Name == deviceName);
1491     CHECK(device->m_Uid >= 1);
1492     CHECK(device->m_Cores == 0);
1493 
1494     // Register a counter with a valid parent category name and associated to a device
1495     const Counter* counterWDevice = nullptr;
1496     CHECK_NOTHROW(counterWDevice = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1497                                                                     16,
1498                                                                     categoryName,
1499                                                                     0,
1500                                                                     1,
1501                                                                     123.45f,
1502                                                                     "valid name 5",
1503                                                                     std::string("valid description"),
1504                                                                     arm::pipe::EmptyOptional(), // Units
1505                                                                     arm::pipe::EmptyOptional(), // Number of cores
1506                                                                     device->m_Uid));        // Device UID
1507     CHECK(counterDirectory.GetCounterCount() == 4);
1508     CHECK(counterWDevice);
1509     CHECK(counterWDevice->m_Uid > counter->m_Uid);
1510     CHECK(counterWDevice->m_MaxCounterUid == counterWDevice->m_Uid);
1511     CHECK(counterWDevice->m_Class == 0);
1512     CHECK(counterWDevice->m_Interpolation == 1);
1513     CHECK(counterWDevice->m_Multiplier == 123.45f);
1514     CHECK(counterWDevice->m_Name == "valid name 5");
1515     CHECK(counterWDevice->m_Description == "valid description");
1516     CHECK(counterWDevice->m_Units == "");
1517     CHECK(counterWDevice->m_DeviceUid == device->m_Uid);
1518     CHECK(counterWDevice->m_CounterSetUid == 0);
1519     CHECK(category->m_Counters.size() == 4);
1520     CHECK(category->m_Counters.back() == counterWDevice->m_Uid);
1521 
1522     // Register a counter with a valid parent category name and not associated with a counter set
1523     const Counter* counterWoCounterSet = nullptr;
1524     CHECK_NOTHROW(counterWoCounterSet = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1525                                                                          17,
1526                                                                          categoryName,
1527                                                                          0,
1528                                                                          1,
1529                                                                          123.45f,
1530                                                                          "valid name 6",
1531                                                                          "valid description",
1532                                                                          arm::pipe::EmptyOptional(),// Units
1533                                                                          arm::pipe::EmptyOptional(),// No of cores
1534                                                                          arm::pipe::EmptyOptional(),// Device UID
1535                                                                          0));               // CounterSet UID
1536     CHECK(counterDirectory.GetCounterCount() == 5);
1537     CHECK(counterWoCounterSet);
1538     CHECK(counterWoCounterSet->m_Uid > counter->m_Uid);
1539     CHECK(counterWoCounterSet->m_MaxCounterUid == counterWoCounterSet->m_Uid);
1540     CHECK(counterWoCounterSet->m_Class == 0);
1541     CHECK(counterWoCounterSet->m_Interpolation == 1);
1542     CHECK(counterWoCounterSet->m_Multiplier == 123.45f);
1543     CHECK(counterWoCounterSet->m_Name == "valid name 6");
1544     CHECK(counterWoCounterSet->m_Description == "valid description");
1545     CHECK(counterWoCounterSet->m_Units == "");
1546     CHECK(counterWoCounterSet->m_DeviceUid == 0);
1547     CHECK(counterWoCounterSet->m_CounterSetUid == 0);
1548     CHECK(category->m_Counters.size() == 5);
1549     CHECK(category->m_Counters.back() == counterWoCounterSet->m_Uid);
1550 
1551     // Register a counter with a valid parent category name and associated to an invalid counter set
1552     CHECK_THROWS_AS(noCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID,
1553                                                                  18,
1554                                                                  categoryName,
1555                                                                  0,
1556                                                                  1,
1557                                                                  123.45f,
1558                                                                  "valid ",
1559                                                                  "name 7",
1560                                                                  std::string("valid description"),
1561                                                                  arm::pipe::EmptyOptional(),    // Units
1562                                                                  arm::pipe::EmptyOptional(),    // Number of cores
1563                                                                  100),            // Counter set UID
1564                       arm::pipe::InvalidArgumentException);
1565     CHECK(counterDirectory.GetCounterCount() == 5);
1566     CHECK(!noCounter);
1567 
1568     // Register a counter with a valid parent category name and with a given number of cores
1569     const Counter* counterWNumberOfCores = nullptr;
1570     uint16_t numberOfCores               = 15;
1571     CHECK_NOTHROW(counterWNumberOfCores = counterDirectory.RegisterCounter(
1572                              armnn::profiling::BACKEND_ID, 50,
1573                              categoryName, 0, 1, 123.45f, "valid name 8", "valid description",
1574                              arm::pipe::EmptyOptional(),      // Units
1575                              numberOfCores,               // Number of cores
1576                              arm::pipe::EmptyOptional(),      // Device UID
1577                              arm::pipe::EmptyOptional()));    // Counter set UID
1578     CHECK(counterDirectory.GetCounterCount() == 20);
1579     CHECK(counterWNumberOfCores);
1580     CHECK(counterWNumberOfCores->m_Uid > counter->m_Uid);
1581     CHECK(counterWNumberOfCores->m_MaxCounterUid == counterWNumberOfCores->m_Uid + numberOfCores - 1);
1582     CHECK(counterWNumberOfCores->m_Class == 0);
1583     CHECK(counterWNumberOfCores->m_Interpolation == 1);
1584     CHECK(counterWNumberOfCores->m_Multiplier == 123.45f);
1585     CHECK(counterWNumberOfCores->m_Name == "valid name 8");
1586     CHECK(counterWNumberOfCores->m_Description == "valid description");
1587     CHECK(counterWNumberOfCores->m_Units == "");
1588     CHECK(counterWNumberOfCores->m_DeviceUid == 0);
1589     CHECK(counterWNumberOfCores->m_CounterSetUid == 0);
1590     CHECK(category->m_Counters.size() == 20);
1591     for (size_t i = 0; i < numberOfCores; i++)
1592     {
1593         CHECK(category->m_Counters[category->m_Counters.size() - numberOfCores + i] ==
1594                     counterWNumberOfCores->m_Uid + i);
1595     }
1596 
1597     // Register a multi-core device for testing
1598     const std::string multiCoreDeviceName = "some_multi_core_device";
1599     const Device* multiCoreDevice         = nullptr;
1600     CHECK_NOTHROW(multiCoreDevice = counterDirectory.RegisterDevice(multiCoreDeviceName, 4));
1601     CHECK(counterDirectory.GetDeviceCount() == 2);
1602     CHECK(multiCoreDevice);
1603     CHECK(multiCoreDevice->m_Name == multiCoreDeviceName);
1604     CHECK(multiCoreDevice->m_Uid >= 1);
1605     CHECK(multiCoreDevice->m_Cores == 4);
1606 
1607     // Register a counter with a valid parent category name and associated to the multi-core device
1608     const Counter* counterWMultiCoreDevice = nullptr;
1609     CHECK_NOTHROW(counterWMultiCoreDevice = counterDirectory.RegisterCounter(
1610                              armnn::profiling::BACKEND_ID, 19, categoryName, 0, 1,
1611                              123.45f, "valid name 9", "valid description",
1612                              arm::pipe::EmptyOptional(),      // Units
1613                              arm::pipe::EmptyOptional(),      // Number of cores
1614                              multiCoreDevice->m_Uid,      // Device UID
1615                              arm::pipe::EmptyOptional()));    // Counter set UID
1616     CHECK(counterDirectory.GetCounterCount() == 24);
1617     CHECK(counterWMultiCoreDevice);
1618     CHECK(counterWMultiCoreDevice->m_Uid > counter->m_Uid);
1619     CHECK(counterWMultiCoreDevice->m_MaxCounterUid ==
1620                 counterWMultiCoreDevice->m_Uid + multiCoreDevice->m_Cores - 1);
1621     CHECK(counterWMultiCoreDevice->m_Class == 0);
1622     CHECK(counterWMultiCoreDevice->m_Interpolation == 1);
1623     CHECK(counterWMultiCoreDevice->m_Multiplier == 123.45f);
1624     CHECK(counterWMultiCoreDevice->m_Name == "valid name 9");
1625     CHECK(counterWMultiCoreDevice->m_Description == "valid description");
1626     CHECK(counterWMultiCoreDevice->m_Units == "");
1627     CHECK(counterWMultiCoreDevice->m_DeviceUid == multiCoreDevice->m_Uid);
1628     CHECK(counterWMultiCoreDevice->m_CounterSetUid == 0);
1629     CHECK(category->m_Counters.size() == 24);
1630     for (size_t i = 0; i < 4; i++)
1631     {
1632         CHECK(category->m_Counters[category->m_Counters.size() - 4 + i] == counterWMultiCoreDevice->m_Uid + i);
1633     }
1634 
1635     // Register a multi-core device associate to a parent category for testing
1636     const std::string multiCoreDeviceNameWParentCategory = "some_multi_core_device_with_parent_category";
1637     const Device* multiCoreDeviceWParentCategory         = nullptr;
1638     CHECK_NOTHROW(multiCoreDeviceWParentCategory =
1639                              counterDirectory.RegisterDevice(multiCoreDeviceNameWParentCategory, 2, categoryName));
1640     CHECK(counterDirectory.GetDeviceCount() == 3);
1641     CHECK(multiCoreDeviceWParentCategory);
1642     CHECK(multiCoreDeviceWParentCategory->m_Name == multiCoreDeviceNameWParentCategory);
1643     CHECK(multiCoreDeviceWParentCategory->m_Uid >= 1);
1644     CHECK(multiCoreDeviceWParentCategory->m_Cores == 2);
1645 
1646     // Register a counter with a valid parent category name and getting the number of cores of the multi-core device
1647     // associated to that category
1648     const Counter* counterWMultiCoreDeviceWParentCategory = nullptr;
1649     uint16_t numberOfCourse = multiCoreDeviceWParentCategory->m_Cores;
1650     CHECK_NOTHROW(counterWMultiCoreDeviceWParentCategory =
1651                                                 counterDirectory.RegisterCounter(
1652                                                     armnn::profiling::BACKEND_ID,
1653                                                     100,
1654                                                     categoryName,
1655                                                     0,
1656                                                     1,
1657                                                     123.45f,
1658                                                     "valid name 10",
1659                                                     "valid description",
1660                                                     arm::pipe::EmptyOptional(),  // Units
1661                                                     numberOfCourse,          // Number of cores
1662                                                     arm::pipe::EmptyOptional(),  // Device UID
1663                                                     arm::pipe::EmptyOptional()));// Counter set UID
1664     CHECK(counterDirectory.GetCounterCount() == 26);
1665     CHECK(counterWMultiCoreDeviceWParentCategory);
1666     CHECK(counterWMultiCoreDeviceWParentCategory->m_Uid > counter->m_Uid);
1667     CHECK(counterWMultiCoreDeviceWParentCategory->m_MaxCounterUid ==
1668                 counterWMultiCoreDeviceWParentCategory->m_Uid + multiCoreDeviceWParentCategory->m_Cores - 1);
1669     CHECK(counterWMultiCoreDeviceWParentCategory->m_Class == 0);
1670     CHECK(counterWMultiCoreDeviceWParentCategory->m_Interpolation == 1);
1671     CHECK(counterWMultiCoreDeviceWParentCategory->m_Multiplier == 123.45f);
1672     CHECK(counterWMultiCoreDeviceWParentCategory->m_Name == "valid name 10");
1673     CHECK(counterWMultiCoreDeviceWParentCategory->m_Description == "valid description");
1674     CHECK(counterWMultiCoreDeviceWParentCategory->m_Units == "");
1675     CHECK(category->m_Counters.size() == 26);
1676     for (size_t i = 0; i < 2; i++)
1677     {
1678         CHECK(category->m_Counters[category->m_Counters.size() - 2 + i] ==
1679                     counterWMultiCoreDeviceWParentCategory->m_Uid + i);
1680     }
1681 
1682     // Register a counter set for testing
1683     const std::string counterSetName = "some_counter_set";
1684     const CounterSet* counterSet     = nullptr;
1685     CHECK_NOTHROW(counterSet = counterDirectory.RegisterCounterSet(counterSetName));
1686     CHECK(counterDirectory.GetCounterSetCount() == 1);
1687     CHECK(counterSet);
1688     CHECK(counterSet->m_Name == counterSetName);
1689     CHECK(counterSet->m_Uid >= 1);
1690     CHECK(counterSet->m_Count == 0);
1691 
1692     // Register a counter with a valid parent category name and associated to a counter set
1693     const Counter* counterWCounterSet = nullptr;
1694     CHECK_NOTHROW(counterWCounterSet = counterDirectory.RegisterCounter(
1695                              armnn::profiling::BACKEND_ID, 300,
1696                              categoryName, 0, 1, 123.45f, "valid name 11", "valid description",
1697                              arm::pipe::EmptyOptional(),    // Units
1698                              0,                         // Number of cores
1699                              arm::pipe::EmptyOptional(),    // Device UID
1700                              counterSet->m_Uid));       // Counter set UID
1701     CHECK(counterDirectory.GetCounterCount() == 27);
1702     CHECK(counterWCounterSet);
1703     CHECK(counterWCounterSet->m_Uid > counter->m_Uid);
1704     CHECK(counterWCounterSet->m_MaxCounterUid == counterWCounterSet->m_Uid);
1705     CHECK(counterWCounterSet->m_Class == 0);
1706     CHECK(counterWCounterSet->m_Interpolation == 1);
1707     CHECK(counterWCounterSet->m_Multiplier == 123.45f);
1708     CHECK(counterWCounterSet->m_Name == "valid name 11");
1709     CHECK(counterWCounterSet->m_Description == "valid description");
1710     CHECK(counterWCounterSet->m_Units == "");
1711     CHECK(counterWCounterSet->m_DeviceUid == 0);
1712     CHECK(counterWCounterSet->m_CounterSetUid == counterSet->m_Uid);
1713     CHECK(category->m_Counters.size() == 27);
1714     CHECK(category->m_Counters.back() == counterWCounterSet->m_Uid);
1715 
1716     // Register a counter with a valid parent category name and associated to a device and a counter set
1717     const Counter* counterWDeviceWCounterSet = nullptr;
1718     CHECK_NOTHROW(counterWDeviceWCounterSet = counterDirectory.RegisterCounter(
1719                              armnn::profiling::BACKEND_ID, 23,
1720                              categoryName, 0, 1, 123.45f, "valid name 12", "valid description",
1721                              arm::pipe::EmptyOptional(),    // Units
1722                              1,                         // Number of cores
1723                              device->m_Uid,             // Device UID
1724                              counterSet->m_Uid));       // Counter set UID
1725     CHECK(counterDirectory.GetCounterCount() == 28);
1726     CHECK(counterWDeviceWCounterSet);
1727     CHECK(counterWDeviceWCounterSet->m_Uid > counter->m_Uid);
1728     CHECK(counterWDeviceWCounterSet->m_MaxCounterUid == counterWDeviceWCounterSet->m_Uid);
1729     CHECK(counterWDeviceWCounterSet->m_Class == 0);
1730     CHECK(counterWDeviceWCounterSet->m_Interpolation == 1);
1731     CHECK(counterWDeviceWCounterSet->m_Multiplier == 123.45f);
1732     CHECK(counterWDeviceWCounterSet->m_Name == "valid name 12");
1733     CHECK(counterWDeviceWCounterSet->m_Description == "valid description");
1734     CHECK(counterWDeviceWCounterSet->m_Units == "");
1735     CHECK(counterWDeviceWCounterSet->m_DeviceUid == device->m_Uid);
1736     CHECK(counterWDeviceWCounterSet->m_CounterSetUid == counterSet->m_Uid);
1737     CHECK(category->m_Counters.size() == 28);
1738     CHECK(category->m_Counters.back() == counterWDeviceWCounterSet->m_Uid);
1739 
1740     // Register another category for testing
1741     const std::string anotherCategoryName = "some_other_category";
1742     const Category* anotherCategory       = nullptr;
1743     CHECK_NOTHROW(anotherCategory = counterDirectory.RegisterCategory(anotherCategoryName));
1744     CHECK(counterDirectory.GetCategoryCount() == 2);
1745     CHECK(anotherCategory);
1746     CHECK(anotherCategory != category);
1747     CHECK(anotherCategory->m_Name == anotherCategoryName);
1748     CHECK(anotherCategory->m_Counters.empty());
1749 
1750     // Register a counter to the other category
1751     const Counter* anotherCounter = nullptr;
1752     CHECK_NOTHROW(anotherCounter = counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID, 24,
1753                                                                     anotherCategoryName, 1, 0, .00043f,
1754                                                                     "valid name", "valid description",
1755                                                                     arm::pipe::EmptyOptional(), // Units
1756                                                                     arm::pipe::EmptyOptional(), // Number of cores
1757                                                                     device->m_Uid,          // Device UID
1758                                                                     counterSet->m_Uid));    // Counter set UID
1759     CHECK(counterDirectory.GetCounterCount() == 29);
1760     CHECK(anotherCounter);
1761     CHECK(anotherCounter->m_MaxCounterUid == anotherCounter->m_Uid);
1762     CHECK(anotherCounter->m_Class == 1);
1763     CHECK(anotherCounter->m_Interpolation == 0);
1764     CHECK(anotherCounter->m_Multiplier == .00043f);
1765     CHECK(anotherCounter->m_Name == "valid name");
1766     CHECK(anotherCounter->m_Description == "valid description");
1767     CHECK(anotherCounter->m_Units == "");
1768     CHECK(anotherCounter->m_DeviceUid == device->m_Uid);
1769     CHECK(anotherCounter->m_CounterSetUid == counterSet->m_Uid);
1770     CHECK(anotherCategory->m_Counters.size() == 1);
1771     CHECK(anotherCategory->m_Counters.back() == anotherCounter->m_Uid);
1772 }
1773 
1774 TEST_CASE("CounterSelectionCommandHandlerParseData")
1775 {
1776     ProfilingStateMachine profilingStateMachine;
1777 
1778     class TestCaptureThread : public IPeriodicCounterCapture
1779     {
Start()1780         void Start() override
1781         {}
Stop()1782         void Stop() override
1783         {}
1784     };
1785 
1786     class TestReadCounterValues : public IReadCounterValues
1787     {
IsCounterRegistered(uint16_t counterUid) const1788         bool IsCounterRegistered(uint16_t counterUid) const override
1789         {
1790             arm::pipe::IgnoreUnused(counterUid);
1791             return true;
1792         }
IsCounterRegistered(const std::string & counterName) const1793         bool IsCounterRegistered(const std::string& counterName) const override
1794         {
1795             arm::pipe::IgnoreUnused(counterName);
1796             return true;
1797         }
GetCounterCount() const1798         uint16_t GetCounterCount() const override
1799         {
1800             return 0;
1801         }
GetAbsoluteCounterValue(uint16_t counterUid) const1802         uint32_t GetAbsoluteCounterValue(uint16_t counterUid) const override
1803         {
1804             arm::pipe::IgnoreUnused(counterUid);
1805             return 0;
1806         }
GetDeltaCounterValue(uint16_t counterUid)1807         uint32_t GetDeltaCounterValue(uint16_t counterUid) override
1808         {
1809             arm::pipe::IgnoreUnused(counterUid);
1810             return 0;
1811         }
1812     };
1813     const uint32_t familyId = 0;
1814     const uint32_t packetId = 0x40000;
1815 
1816     uint32_t version = 1;
1817     const std::unordered_map<std::string,
1818             std::shared_ptr<IBackendProfilingContext>> backendProfilingContext;
1819     CounterIdMap counterIdMap;
1820     Holder holder;
1821     TestCaptureThread captureThread;
1822     TestReadCounterValues readCounterValues;
1823     MockBufferManager mockBuffer(512);
1824     SendCounterPacket sendCounterPacket(mockBuffer,
1825                                         arm::pipe::ARMNN_SOFTWARE_INFO,
1826                                         arm::pipe::ARMNN_SOFTWARE_VERSION,
1827                                         arm::pipe::ARMNN_HARDWARE_VERSION);
1828     SendThread sendThread(profilingStateMachine, mockBuffer, sendCounterPacket);
1829 
1830     uint32_t sizeOfUint32 = arm::pipe::numeric_cast<uint32_t>(sizeof(uint32_t));
1831     uint32_t sizeOfUint16 = arm::pipe::numeric_cast<uint32_t>(sizeof(uint16_t));
1832 
1833     // Data with period and counters
1834     uint32_t period1     = arm::pipe::LOWEST_CAPTURE_PERIOD;
1835     uint32_t dataLength1 = 8;
1836     uint32_t offset      = 0;
1837 
1838     std::unique_ptr<unsigned char[]> uniqueData1 = std::make_unique<unsigned char[]>(dataLength1);
1839     unsigned char* data1                         = reinterpret_cast<unsigned char*>(uniqueData1.get());
1840 
1841     WriteUint32(data1, offset, period1);
1842     offset += sizeOfUint32;
1843     WriteUint16(data1, offset, 4000);
1844     offset += sizeOfUint16;
1845     WriteUint16(data1, offset, 5000);
1846 
1847     arm::pipe::Packet packetA(packetId, dataLength1, uniqueData1);
1848 
1849     PeriodicCounterSelectionCommandHandler commandHandler(familyId, packetId, version, backendProfilingContext,
1850                                                           counterIdMap, holder, 10000u, captureThread,
1851                                                           readCounterValues, sendCounterPacket, profilingStateMachine);
1852 
1853     profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
1854     CHECK_THROWS_AS(commandHandler(packetA), arm::pipe::ProfilingException);
1855     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
1856     CHECK_THROWS_AS(commandHandler(packetA), arm::pipe::ProfilingException);
1857     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
1858     CHECK_THROWS_AS(commandHandler(packetA), arm::pipe::ProfilingException);
1859     profilingStateMachine.TransitionToState(ProfilingState::Active);
1860     CHECK_NOTHROW(commandHandler(packetA));
1861 
1862     const std::vector<uint16_t> counterIdsA = holder.GetCaptureData().GetCounterIds();
1863 
1864     CHECK(holder.GetCaptureData().GetCapturePeriod() == period1);
1865     CHECK(counterIdsA.size() == 2);
1866     CHECK(counterIdsA[0] == 4000);
1867     CHECK(counterIdsA[1] == 5000);
1868 
1869     auto readBuffer = mockBuffer.GetReadableBuffer();
1870 
1871     offset = 0;
1872 
1873     uint32_t headerWord0 = ReadUint32(readBuffer, offset);
1874     offset += sizeOfUint32;
1875     uint32_t headerWord1 = ReadUint32(readBuffer, offset);
1876     offset += sizeOfUint32;
1877     uint32_t period = ReadUint32(readBuffer, offset);
1878 
1879     CHECK(((headerWord0 >> 26) & 0x3F) == 0);             // packet family
1880     CHECK(((headerWord0 >> 16) & 0x3FF) == 4);            // packet id
1881     CHECK(headerWord1 == 8);                              // data length
1882     CHECK(period ==  arm::pipe::LOWEST_CAPTURE_PERIOD);   // capture period
1883 
1884     uint16_t counterId = 0;
1885     offset += sizeOfUint32;
1886     counterId = ReadUint16(readBuffer, offset);
1887     CHECK(counterId == 4000);
1888     offset += sizeOfUint16;
1889     counterId = ReadUint16(readBuffer, offset);
1890     CHECK(counterId == 5000);
1891 
1892     mockBuffer.MarkRead(readBuffer);
1893 
1894     // Data with period only
1895     uint32_t period2     = 9000; // We'll specify a value below LOWEST_CAPTURE_PERIOD. It should be pulled upwards.
1896     uint32_t dataLength2 = 4;
1897 
1898     std::unique_ptr<unsigned char[]> uniqueData2 = std::make_unique<unsigned char[]>(dataLength2);
1899 
1900     WriteUint32(reinterpret_cast<unsigned char*>(uniqueData2.get()), 0, period2);
1901 
1902     arm::pipe::Packet packetB(packetId, dataLength2, uniqueData2);
1903 
1904     commandHandler(packetB);
1905 
1906     const std::vector<uint16_t> counterIdsB = holder.GetCaptureData().GetCounterIds();
1907 
1908     // Value should have been pulled up from 9000 to LOWEST_CAPTURE_PERIOD.
1909     CHECK(holder.GetCaptureData().GetCapturePeriod() ==  arm::pipe::LOWEST_CAPTURE_PERIOD);
1910     CHECK(counterIdsB.size() == 0);
1911 
1912     readBuffer = mockBuffer.GetReadableBuffer();
1913 
1914     offset = 0;
1915 
1916     headerWord0 = ReadUint32(readBuffer, offset);
1917     offset += sizeOfUint32;
1918     headerWord1 = ReadUint32(readBuffer, offset);
1919     offset += sizeOfUint32;
1920     period = ReadUint32(readBuffer, offset);
1921 
1922     CHECK(((headerWord0 >> 26) & 0x3F) == 0);          // packet family
1923     CHECK(((headerWord0 >> 16) & 0x3FF) == 4);         // packet id
1924     CHECK(headerWord1 == 4);                           // data length
1925     CHECK(period == arm::pipe::LOWEST_CAPTURE_PERIOD); // capture period
1926 }
1927 
1928 TEST_CASE("CheckTimelineActivationAndDeactivation")
1929 {
1930     class TestReportStructure : public IReportStructure
1931     {
1932         public:
ReportStructure(arm::pipe::IProfilingService & profilingService)1933         virtual void ReportStructure(arm::pipe::IProfilingService& profilingService) override
1934         {
1935             arm::pipe::IgnoreUnused(profilingService);
1936             m_ReportStructureCalled = true;
1937         }
1938 
1939         bool m_ReportStructureCalled = false;
1940     };
1941 
1942     class TestNotifyBackends : public INotifyBackends
1943     {
1944         public:
TestNotifyBackends()1945         TestNotifyBackends() : m_timelineReporting(false) {}
NotifyBackendsForTimelineReporting()1946         virtual void NotifyBackendsForTimelineReporting() override
1947         {
1948             m_TestNotifyBackendsCalled = m_timelineReporting.load();
1949         }
1950 
1951         bool m_TestNotifyBackendsCalled = false;
1952         std::atomic<bool> m_timelineReporting;
1953     };
1954 
1955     arm::pipe::PacketVersionResolver packetVersionResolver;
1956 
1957     BufferManager bufferManager(512);
1958     SendTimelinePacket sendTimelinePacket(bufferManager);
1959     ProfilingStateMachine stateMachine;
1960     TestReportStructure testReportStructure;
1961     TestNotifyBackends testNotifyBackends;
1962     armnn::ArmNNProfilingServiceInitialiser initialiser;
1963     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
1964                                       initialiser,
1965                                       arm::pipe::ARMNN_SOFTWARE_INFO,
1966                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
1967                                       arm::pipe::ARMNN_HARDWARE_VERSION);
1968 
1969 
1970     ActivateTimelineReportingCommandHandler activateTimelineReportingCommandHandler(0,
1971                                                            6,
1972                                                            packetVersionResolver.ResolvePacketVersion(0, 6)
1973                                                            .GetEncodedValue(),
1974                                                            sendTimelinePacket,
1975                                                            stateMachine,
1976                                                            testReportStructure,
1977                                                            testNotifyBackends.m_timelineReporting,
1978                                                            testNotifyBackends,
1979                                                            profilingService);
1980 
1981     // Write an "ActivateTimelineReporting" packet into the mock profiling connection, to simulate an input from an
1982     // external profiling service
1983     const uint32_t packetFamily1 = 0;
1984     const uint32_t packetId1     = 6;
1985     uint32_t packetHeader1 = ConstructHeader(packetFamily1, packetId1);
1986 
1987     // Create the ActivateTimelineReportingPacket
1988     arm::pipe::Packet ActivateTimelineReportingPacket(packetHeader1); // Length == 0
1989 
1990     CHECK_THROWS_AS(
1991         activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket),
1992                                                            arm::pipe::ProfilingException);
1993 
1994     stateMachine.TransitionToState(ProfilingState::NotConnected);
1995     CHECK_THROWS_AS(
1996         activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket),
1997                                                            arm::pipe::ProfilingException);
1998 
1999     stateMachine.TransitionToState(ProfilingState::WaitingForAck);
2000     CHECK_THROWS_AS(
2001         activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket),
2002                                                            arm::pipe::ProfilingException);
2003 
2004     stateMachine.TransitionToState(ProfilingState::Active);
2005     activateTimelineReportingCommandHandler.operator()(ActivateTimelineReportingPacket);
2006 
2007     CHECK(testReportStructure.m_ReportStructureCalled);
2008     CHECK(testNotifyBackends.m_TestNotifyBackendsCalled);
2009     CHECK(testNotifyBackends.m_timelineReporting.load());
2010 
2011     DeactivateTimelineReportingCommandHandler deactivateTimelineReportingCommandHandler(0,
2012                                                   7,
2013                                                   packetVersionResolver.ResolvePacketVersion(0, 7).GetEncodedValue(),
2014                                                   testNotifyBackends.m_timelineReporting,
2015                                                   stateMachine,
2016                                                   testNotifyBackends);
2017 
2018     const uint32_t packetFamily2 = 0;
2019     const uint32_t packetId2     = 7;
2020     uint32_t packetHeader2 = ConstructHeader(packetFamily2, packetId2);
2021 
2022     // Create the DeactivateTimelineReportingPacket
2023     arm::pipe::Packet deactivateTimelineReportingPacket(packetHeader2); // Length == 0
2024 
2025     stateMachine.Reset();
2026     CHECK_THROWS_AS(
2027         deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket),
2028                                                              arm::pipe::ProfilingException);
2029 
2030     stateMachine.TransitionToState(ProfilingState::NotConnected);
2031     CHECK_THROWS_AS(
2032         deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket),
2033                                                              arm::pipe::ProfilingException);
2034 
2035     stateMachine.TransitionToState(ProfilingState::WaitingForAck);
2036     CHECK_THROWS_AS(
2037         deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket),
2038                                                              arm::pipe::ProfilingException);
2039 
2040     stateMachine.TransitionToState(ProfilingState::Active);
2041     deactivateTimelineReportingCommandHandler.operator()(deactivateTimelineReportingPacket);
2042 
2043     CHECK(!testNotifyBackends.m_TestNotifyBackendsCalled);
2044     CHECK(!testNotifyBackends.m_timelineReporting.load());
2045 }
2046 
2047 TEST_CASE("CheckProfilingServiceNotActive")
2048 {
2049     using namespace armnn;
2050 
2051     // Create runtime in which the test will run
2052     armnn::IRuntime::CreationOptions options;
2053     options.m_ProfilingOptions.m_EnableProfiling = true;
2054 
2055     armnn::RuntimeImpl runtime(options);
2056     armnn::ArmNNProfilingServiceInitialiser initialiser;
2057     ProfilingServiceRuntimeHelper profilingServiceHelper(
2058         arm::pipe::MAX_ARMNN_COUNTER, initialiser, GetProfilingService(&runtime));
2059     profilingServiceHelper.ForceTransitionToState(ProfilingState::NotConnected);
2060     profilingServiceHelper.ForceTransitionToState(ProfilingState::WaitingForAck);
2061     profilingServiceHelper.ForceTransitionToState(ProfilingState::Active);
2062 
2063     BufferManager& bufferManager = profilingServiceHelper.GetProfilingBufferManager();
2064     auto readableBuffer = bufferManager.GetReadableBuffer();
2065 
2066     // Profiling is enabled, the post-optimisation structure should be created
2067     CHECK(readableBuffer == nullptr);
2068 }
2069 
2070 TEST_CASE("CheckConnectionAcknowledged")
2071 {
2072     const uint32_t packetFamilyId     = 0;
2073     const uint32_t connectionPacketId = 0x10000;
2074     const uint32_t version            = 1;
2075 
2076     uint32_t sizeOfUint32 = arm::pipe::numeric_cast<uint32_t>(sizeof(uint32_t));
2077     uint32_t sizeOfUint16 = arm::pipe::numeric_cast<uint32_t>(sizeof(uint16_t));
2078 
2079     // Data with period and counters
2080     uint32_t period1     = 10;
2081     uint32_t dataLength1 = 8;
2082     uint32_t offset      = 0;
2083 
2084     std::unique_ptr<unsigned char[]> uniqueData1 = std::make_unique<unsigned char[]>(dataLength1);
2085     unsigned char* data1                         = reinterpret_cast<unsigned char*>(uniqueData1.get());
2086 
2087     WriteUint32(data1, offset, period1);
2088     offset += sizeOfUint32;
2089     WriteUint16(data1, offset, 4000);
2090     offset += sizeOfUint16;
2091     WriteUint16(data1, offset, 5000);
2092 
2093     arm::pipe::Packet packetA(connectionPacketId, dataLength1, uniqueData1);
2094 
2095     ProfilingStateMachine profilingState(ProfilingState::Uninitialised);
2096     CHECK(profilingState.GetCurrentState() == ProfilingState::Uninitialised);
2097     CounterDirectory counterDirectory;
2098     MockBufferManager mockBuffer(1024);
2099     SendCounterPacket sendCounterPacket(mockBuffer,
2100                                         arm::pipe::ARMNN_SOFTWARE_INFO,
2101                                         arm::pipe::ARMNN_SOFTWARE_VERSION,
2102                                         arm::pipe::ARMNN_HARDWARE_VERSION);
2103     SendThread sendThread(profilingState, mockBuffer, sendCounterPacket);
2104     SendTimelinePacket sendTimelinePacket(mockBuffer);
2105     MockProfilingServiceStatus mockProfilingServiceStatus;
2106 
2107     ConnectionAcknowledgedCommandHandler commandHandler(packetFamilyId,
2108                                                         connectionPacketId,
2109                                                         version,
2110                                                         counterDirectory,
2111                                                         sendCounterPacket,
2112                                                         sendTimelinePacket,
2113                                                         profilingState,
2114                                                         mockProfilingServiceStatus);
2115 
2116     // command handler received packet on ProfilingState::Uninitialised
2117     CHECK_THROWS_AS(commandHandler(packetA), arm::pipe::ProfilingException);
2118 
2119     profilingState.TransitionToState(ProfilingState::NotConnected);
2120     CHECK(profilingState.GetCurrentState() == ProfilingState::NotConnected);
2121     // command handler received packet on ProfilingState::NotConnected
2122     CHECK_THROWS_AS(commandHandler(packetA), arm::pipe::ProfilingException);
2123 
2124     profilingState.TransitionToState(ProfilingState::WaitingForAck);
2125     CHECK(profilingState.GetCurrentState() == ProfilingState::WaitingForAck);
2126     // command handler received packet on ProfilingState::WaitingForAck
2127     CHECK_NOTHROW(commandHandler(packetA));
2128     CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
2129 
2130     // command handler received packet on ProfilingState::Active
2131     CHECK_NOTHROW(commandHandler(packetA));
2132     CHECK(profilingState.GetCurrentState() == ProfilingState::Active);
2133 
2134     // command handler received different packet
2135     const uint32_t differentPacketId = 0x40000;
2136     arm::pipe::Packet packetB(differentPacketId, dataLength1, uniqueData1);
2137     profilingState.TransitionToState(ProfilingState::NotConnected);
2138     profilingState.TransitionToState(ProfilingState::WaitingForAck);
2139     ConnectionAcknowledgedCommandHandler differentCommandHandler(packetFamilyId,
2140                                                                  differentPacketId,
2141                                                                  version,
2142                                                                  counterDirectory,
2143                                                                  sendCounterPacket,
2144                                                                  sendTimelinePacket,
2145                                                                  profilingState,
2146                                                                  mockProfilingServiceStatus);
2147     CHECK_THROWS_AS(differentCommandHandler(packetB), arm::pipe::ProfilingException);
2148 }
2149 
2150 TEST_CASE("CheckSocketConnectionException")
2151 {
2152     // Check that creating a SocketProfilingConnection armnnProfiling in an exception as the Gator UDS doesn't exist.
2153     CHECK_THROWS_AS(new SocketProfilingConnection(), arm::pipe::SocketConnectionException);
2154 }
2155 
2156 TEST_CASE("CheckSocketConnectionException2")
2157 {
2158     try
2159     {
2160         new SocketProfilingConnection();
2161     }
2162     catch (const arm::pipe::SocketConnectionException& ex)
2163     {
2164         CHECK(ex.GetSocketFd() == 0);
2165         CHECK(ex.GetErrorNo() == ECONNREFUSED);
2166         CHECK(ex.what()
2167                     == std::string("SocketProfilingConnection: Cannot connect to stream socket: Connection refused"));
2168     }
2169 }
2170 
2171 TEST_CASE("SwTraceIsValidCharTest")
2172 {
2173     // Only ASCII 7-bit encoding supported
2174     for (unsigned char c = 0; c < 128; c++)
2175     {
2176         CHECK(arm::pipe::SwTraceCharPolicy::IsValidChar(c));
2177     }
2178 
2179     // Not ASCII
2180     for (unsigned char c = 255; c >= 128; c++)
2181     {
2182         CHECK(!arm::pipe::SwTraceCharPolicy::IsValidChar(c));
2183     }
2184 }
2185 
2186 TEST_CASE("SwTraceIsValidNameCharTest")
2187 {
2188     // Only alpha-numeric and underscore ASCII 7-bit encoding supported
2189     const unsigned char validChars[] = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_";
2190     for (unsigned char i = 0; i < sizeof(validChars) / sizeof(validChars[0]) - 1; i++)
2191     {
2192         CHECK(arm::pipe::SwTraceNameCharPolicy::IsValidChar(validChars[i]));
2193     }
2194 
2195     // Non alpha-numeric chars
2196     for (unsigned char c = 0; c < 48; c++)
2197     {
2198         CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2199     }
2200     for (unsigned char c = 58; c < 65; c++)
2201     {
2202         CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2203     }
2204     for (unsigned char c = 91; c < 95; c++)
2205     {
2206         CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2207     }
2208     for (unsigned char c = 96; c < 97; c++)
2209     {
2210         CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2211     }
2212     for (unsigned char c = 123; c < 128; c++)
2213     {
2214         CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2215     }
2216 
2217     // Not ASCII
2218     for (unsigned char c = 255; c >= 128; c++)
2219     {
2220         CHECK(!arm::pipe::SwTraceNameCharPolicy::IsValidChar(c));
2221     }
2222 }
2223 
2224 TEST_CASE("IsValidSwTraceStringTest")
2225 {
2226     // Valid SWTrace strings
2227     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>(""));
2228     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("_"));
2229     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("0123"));
2230     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("valid_string"));
2231     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("VALID_string_456"));
2232     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>(" "));
2233     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("valid string"));
2234     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("!$%"));
2235     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("valid|\\~string#123"));
2236 
2237     // Invalid SWTrace strings
2238     CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("€£"));
2239     CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("invalid‡string"));
2240     CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceCharPolicy>("12Ž34"));
2241 }
2242 
2243 TEST_CASE("IsValidSwTraceNameStringTest")
2244 {
2245     // Valid SWTrace name strings
2246     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>(""));
2247     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("_"));
2248     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("0123"));
2249     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("valid_string"));
2250     CHECK(arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("VALID_string_456"));
2251 
2252     // Invalid SWTrace name strings
2253     CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>(" "));
2254     CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid string"));
2255     CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("!$%"));
2256     CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid|\\~string#123"));
2257     CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("€£"));
2258     CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid‡string"));
2259     CHECK(!arm::pipe::IsValidSwTraceString<arm::pipe::SwTraceNameCharPolicy>("12Ž34"));
2260 }
2261 
2262 template <typename SwTracePolicy>
StringToSwTraceStringTestHelper(const std::string & testString,std::vector<uint32_t> buffer,size_t expectedSize)2263 void StringToSwTraceStringTestHelper(const std::string& testString, std::vector<uint32_t> buffer, size_t expectedSize)
2264 {
2265     // Convert the test string to a SWTrace string
2266     CHECK(arm::pipe::StringToSwTraceString<SwTracePolicy>(testString, buffer));
2267 
2268     // The buffer must contain at least the length of the string
2269     CHECK(!buffer.empty());
2270 
2271     // The buffer must be of the expected size (in words)
2272     CHECK(buffer.size() == expectedSize);
2273 
2274     // The first word of the byte must be the length of the string including the null-terminator
2275     CHECK(buffer[0] == testString.size() + 1);
2276 
2277     // The contents of the buffer must match the test string
2278     CHECK(std::memcmp(testString.data(), buffer.data() + 1, testString.size()) == 0);
2279 
2280     // The buffer must include the null-terminator at the end of the string
2281     size_t nullTerminatorIndex = sizeof(uint32_t) + testString.size();
2282     CHECK(reinterpret_cast<unsigned char*>(buffer.data())[nullTerminatorIndex] == '\0');
2283 }
2284 
2285 TEST_CASE("StringToSwTraceStringTest")
2286 {
2287     std::vector<uint32_t> buffer;
2288 
2289     // Valid SWTrace strings (expected size in words)
2290     StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("", buffer, 2);
2291     StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("_", buffer, 2);
2292     StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("0123", buffer, 3);
2293     StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("valid_string", buffer, 5);
2294     StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("VALID_string_456", buffer, 6);
2295     StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>(" ", buffer, 2);
2296     StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("valid string", buffer, 5);
2297     StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("!$%", buffer, 2);
2298     StringToSwTraceStringTestHelper<arm::pipe::SwTraceCharPolicy>("valid|\\~string#123", buffer, 6);
2299 
2300     // Invalid SWTrace strings
2301     CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceCharPolicy>("€£", buffer));
2302     CHECK(buffer.empty());
2303     CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceCharPolicy>("invalid‡string", buffer));
2304     CHECK(buffer.empty());
2305     CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceCharPolicy>("12Ž34", buffer));
2306     CHECK(buffer.empty());
2307 }
2308 
2309 TEST_CASE("StringToSwTraceNameStringTest")
2310 {
2311     std::vector<uint32_t> buffer;
2312 
2313     // Valid SWTrace namestrings (expected size in words)
2314     StringToSwTraceStringTestHelper<arm::pipe::SwTraceNameCharPolicy>("", buffer, 2);
2315     StringToSwTraceStringTestHelper<arm::pipe::SwTraceNameCharPolicy>("_", buffer, 2);
2316     StringToSwTraceStringTestHelper<arm::pipe::SwTraceNameCharPolicy>("0123", buffer, 3);
2317     StringToSwTraceStringTestHelper<arm::pipe::SwTraceNameCharPolicy>("valid_string", buffer, 5);
2318     StringToSwTraceStringTestHelper<arm::pipe::SwTraceNameCharPolicy>("VALID_string_456", buffer, 6);
2319 
2320     // Invalid SWTrace namestrings
2321     CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>(" ", buffer));
2322     CHECK(buffer.empty());
2323     CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid string", buffer));
2324     CHECK(buffer.empty());
2325     CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("!$%", buffer));
2326     CHECK(buffer.empty());
2327     CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid|\\~string#123", buffer));
2328     CHECK(buffer.empty());
2329     CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("€£", buffer));
2330     CHECK(buffer.empty());
2331     CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("invalid‡string", buffer));
2332     CHECK(buffer.empty());
2333     CHECK(!arm::pipe::StringToSwTraceString<arm::pipe::SwTraceNameCharPolicy>("12Ž34", buffer));
2334     CHECK(buffer.empty());
2335 }
2336 
2337 TEST_CASE("CheckPeriodicCounterCaptureThread")
2338 {
2339     class CaptureReader : public IReadCounterValues
2340     {
2341     public:
CaptureReader(uint16_t counterSize)2342         CaptureReader(uint16_t counterSize)
2343         {
2344             for (uint16_t i = 0; i < counterSize; ++i)
2345             {
2346                 m_Data[i] = 0;
2347             }
2348             m_CounterSize = counterSize;
2349         }
2350         //not used
IsCounterRegistered(uint16_t counterUid) const2351         bool IsCounterRegistered(uint16_t counterUid) const override
2352         {
2353             arm::pipe::IgnoreUnused(counterUid);
2354             return false;
2355         }
IsCounterRegistered(const std::string & counterName) const2356         bool IsCounterRegistered(const std::string& counterName) const override
2357         {
2358             arm::pipe::IgnoreUnused(counterName);
2359             return false;
2360         }
GetCounterCount() const2361         uint16_t GetCounterCount() const override
2362         {
2363             return m_CounterSize;
2364         }
2365 
GetAbsoluteCounterValue(uint16_t counterUid) const2366         uint32_t GetAbsoluteCounterValue(uint16_t counterUid) const override
2367         {
2368             if (counterUid > m_CounterSize)
2369             {
2370                 FAIL("Invalid counter Uid");
2371             }
2372             return m_Data.at(counterUid).load();
2373         }
2374 
GetDeltaCounterValue(uint16_t counterUid)2375         uint32_t GetDeltaCounterValue(uint16_t counterUid)  override
2376         {
2377             if (counterUid > m_CounterSize)
2378             {
2379                 FAIL("Invalid counter Uid");
2380             }
2381             return m_Data.at(counterUid).load();
2382         }
2383 
SetCounterValue(uint16_t counterUid,uint32_t value)2384         void SetCounterValue(uint16_t counterUid, uint32_t value)
2385         {
2386             if (counterUid > m_CounterSize)
2387             {
2388                 FAIL("Invalid counter Uid");
2389             }
2390             m_Data.at(counterUid).store(value);
2391         }
2392 
2393     private:
2394         std::unordered_map<uint16_t, std::atomic<uint32_t>> m_Data;
2395         uint16_t m_CounterSize;
2396     };
2397 
2398     ProfilingStateMachine profilingStateMachine;
2399 
2400     const std::unordered_map<std::string,
2401             std::shared_ptr<IBackendProfilingContext>> backendProfilingContext;
2402     CounterIdMap counterIdMap;
2403     Holder data;
2404     std::vector<uint16_t> captureIds1 = { 0, 1 };
2405     std::vector<uint16_t> captureIds2;
2406 
2407     MockBufferManager mockBuffer(512);
2408     SendCounterPacket sendCounterPacket(mockBuffer,
2409                                         arm::pipe::ARMNN_SOFTWARE_INFO,
2410                                         arm::pipe::ARMNN_SOFTWARE_VERSION,
2411                                         arm::pipe::ARMNN_HARDWARE_VERSION);
2412     SendThread sendThread(profilingStateMachine, mockBuffer, sendCounterPacket);
2413 
2414     std::vector<uint16_t> counterIds;
2415     CaptureReader captureReader(2);
2416 
2417     unsigned int valueA   = 10;
2418     unsigned int valueB   = 15;
2419     unsigned int numSteps = 5;
2420 
2421     PeriodicCounterCapture periodicCounterCapture(std::ref(data), std::ref(sendCounterPacket), captureReader,
2422                                                   counterIdMap, backendProfilingContext);
2423 
2424     for (unsigned int i = 0; i < numSteps; ++i)
2425     {
2426         data.SetCaptureData(1, captureIds1, {});
2427         captureReader.SetCounterValue(0, valueA * (i + 1));
2428         captureReader.SetCounterValue(1, valueB * (i + 1));
2429 
2430         periodicCounterCapture.Start();
2431         periodicCounterCapture.Stop();
2432     }
2433 
2434     auto buffer = mockBuffer.GetReadableBuffer();
2435 
2436     uint32_t headerWord0 = ReadUint32(buffer, 0);
2437     uint32_t headerWord1 = ReadUint32(buffer, 4);
2438 
2439     CHECK(((headerWord0 >> 26) & 0x0000003F) == 3);    // packet family
2440     CHECK(((headerWord0 >> 19) & 0x0000007F) == 0);    // packet class
2441     CHECK(((headerWord0 >> 16) & 0x00000007) == 0);    // packet type
2442     CHECK(headerWord1 == 20);
2443 
2444     uint32_t offset    = 16;
2445     uint16_t readIndex = ReadUint16(buffer, offset);
2446     CHECK(0 == readIndex);
2447 
2448     offset += 2;
2449     uint32_t readValue = ReadUint32(buffer, offset);
2450     CHECK((valueA * numSteps) == readValue);
2451 
2452     offset += 4;
2453     readIndex = ReadUint16(buffer, offset);
2454     CHECK(1 == readIndex);
2455 
2456     offset += 2;
2457     readValue = ReadUint32(buffer, offset);
2458     CHECK((valueB * numSteps) == readValue);
2459 }
2460 
2461 TEST_CASE("RequestCounterDirectoryCommandHandlerTest1")
2462 {
2463     const uint32_t familyId = 0;
2464     const uint32_t packetId = 3;
2465     const uint32_t version  = 1;
2466     ProfilingStateMachine profilingStateMachine;
2467     CounterDirectory counterDirectory;
2468     MockBufferManager mockBuffer1(1024);
2469     SendCounterPacket sendCounterPacket(mockBuffer1,
2470                                         arm::pipe::ARMNN_SOFTWARE_INFO,
2471                                         arm::pipe::ARMNN_SOFTWARE_VERSION,
2472                                         arm::pipe::ARMNN_HARDWARE_VERSION);
2473     SendThread sendThread(profilingStateMachine, mockBuffer1, sendCounterPacket);
2474     MockBufferManager mockBuffer2(1024);
2475     SendTimelinePacket sendTimelinePacket(mockBuffer2);
2476     RequestCounterDirectoryCommandHandler commandHandler(familyId, packetId, version, counterDirectory,
2477                                                          sendCounterPacket, sendTimelinePacket, profilingStateMachine);
2478 
2479     const uint32_t wrongPacketId = 47;
2480     const uint32_t wrongHeader   = (wrongPacketId & 0x000003FF) << 16;
2481 
2482     arm::pipe::Packet wrongPacket(wrongHeader);
2483 
2484     profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
2485     CHECK_THROWS_AS(commandHandler(wrongPacket), arm::pipe::ProfilingException); // Wrong profiling state
2486     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
2487     CHECK_THROWS_AS(commandHandler(wrongPacket), arm::pipe::ProfilingException); // Wrong profiling state
2488     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
2489     CHECK_THROWS_AS(commandHandler(wrongPacket), arm::pipe::ProfilingException); // Wrong profiling state
2490     profilingStateMachine.TransitionToState(ProfilingState::Active);
2491     CHECK_THROWS_AS(commandHandler(wrongPacket), arm::pipe::InvalidArgumentException); // Wrong packet
2492 
2493     const uint32_t rightHeader = (packetId & 0x000003FF) << 16;
2494 
2495     arm::pipe::Packet rightPacket(rightHeader);
2496 
2497     CHECK_NOTHROW(commandHandler(rightPacket)); // Right packet
2498 
2499     auto readBuffer1 = mockBuffer1.GetReadableBuffer();
2500 
2501     uint32_t header1Word0 = ReadUint32(readBuffer1, 0);
2502     uint32_t header1Word1 = ReadUint32(readBuffer1, 4);
2503 
2504     // Counter directory packet
2505     CHECK(((header1Word0 >> 26) & 0x0000003F) == 0); // packet family
2506     CHECK(((header1Word0 >> 16) & 0x000003FF) == 2); // packet id
2507     CHECK(header1Word1 == 24);                       // data length
2508 
2509     uint32_t bodyHeader1Word0   = ReadUint32(readBuffer1, 8);
2510     uint16_t deviceRecordCount = arm::pipe::numeric_cast<uint16_t>(bodyHeader1Word0 >> 16);
2511     CHECK(deviceRecordCount == 0); // device_records_count
2512 
2513     auto readBuffer2 = mockBuffer2.GetReadableBuffer();
2514 
2515     uint32_t header2Word0 = ReadUint32(readBuffer2, 0);
2516     uint32_t header2Word1 = ReadUint32(readBuffer2, 4);
2517 
2518     // Timeline message directory packet
2519     CHECK(((header2Word0 >> 26) & 0x0000003F) == 1); // packet family
2520     CHECK(((header2Word0 >> 16) & 0x000003FF) == 0); // packet id
2521     CHECK(header2Word1 == 443);                      // data length
2522 }
2523 
2524 TEST_CASE("RequestCounterDirectoryCommandHandlerTest2")
2525 {
2526     const uint32_t familyId = 0;
2527     const uint32_t packetId = 3;
2528     const uint32_t version  = 1;
2529     ProfilingStateMachine profilingStateMachine;
2530     CounterDirectory counterDirectory;
2531     MockBufferManager mockBuffer1(1024);
2532     SendCounterPacket sendCounterPacket(mockBuffer1,
2533                                         arm::pipe::ARMNN_SOFTWARE_INFO,
2534                                         arm::pipe::ARMNN_SOFTWARE_VERSION,
2535                                         arm::pipe::ARMNN_HARDWARE_VERSION);
2536     SendThread sendThread(profilingStateMachine, mockBuffer1, sendCounterPacket);
2537     MockBufferManager mockBuffer2(1024);
2538     SendTimelinePacket sendTimelinePacket(mockBuffer2);
2539     RequestCounterDirectoryCommandHandler commandHandler(familyId, packetId, version, counterDirectory,
2540                                                          sendCounterPacket, sendTimelinePacket, profilingStateMachine);
2541     const uint32_t header = (packetId & 0x000003FF) << 16;
2542     const arm::pipe::Packet packet(header);
2543 
2544     const Device* device = counterDirectory.RegisterDevice("deviceA", 1);
2545     CHECK(device != nullptr);
2546     const CounterSet* counterSet = counterDirectory.RegisterCounterSet("countersetA");
2547     CHECK(counterSet != nullptr);
2548     counterDirectory.RegisterCategory("categoryA");
2549     counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID, 24,
2550                                      "categoryA", 0, 1, 2.0f, "counterA", "descA");
2551     counterDirectory.RegisterCounter(armnn::profiling::BACKEND_ID, 25,
2552                                      "categoryA", 1, 1, 3.0f, "counterB", "descB");
2553 
2554     profilingStateMachine.TransitionToState(ProfilingState::Uninitialised);
2555     CHECK_THROWS_AS(commandHandler(packet), arm::pipe::ProfilingException);    // Wrong profiling state
2556     profilingStateMachine.TransitionToState(ProfilingState::NotConnected);
2557     CHECK_THROWS_AS(commandHandler(packet), arm::pipe::ProfilingException);    // Wrong profiling state
2558     profilingStateMachine.TransitionToState(ProfilingState::WaitingForAck);
2559     CHECK_THROWS_AS(commandHandler(packet), arm::pipe::ProfilingException);    // Wrong profiling state
2560     profilingStateMachine.TransitionToState(ProfilingState::Active);
2561     CHECK_NOTHROW(commandHandler(packet));
2562 
2563     auto readBuffer1 = mockBuffer1.GetReadableBuffer();
2564 
2565     const uint32_t header1Word0 = ReadUint32(readBuffer1, 0);
2566     const uint32_t header1Word1 = ReadUint32(readBuffer1, 4);
2567 
2568     CHECK(((header1Word0 >> 26) & 0x0000003F) == 0); // packet family
2569     CHECK(((header1Word0 >> 16) & 0x000003FF) == 2); // packet id
2570     CHECK(header1Word1 == 236);                      // data length
2571 
2572     const uint32_t bodyHeaderSizeBytes = bodyHeaderSize * sizeof(uint32_t);
2573 
2574     const uint32_t bodyHeader1Word0      = ReadUint32(readBuffer1, 8);
2575     const uint32_t bodyHeader1Word1      = ReadUint32(readBuffer1, 12);
2576     const uint32_t bodyHeader1Word2      = ReadUint32(readBuffer1, 16);
2577     const uint32_t bodyHeader1Word3      = ReadUint32(readBuffer1, 20);
2578     const uint32_t bodyHeader1Word4      = ReadUint32(readBuffer1, 24);
2579     const uint32_t bodyHeader1Word5      = ReadUint32(readBuffer1, 28);
2580     const uint16_t deviceRecordCount     = arm::pipe::numeric_cast<uint16_t>(bodyHeader1Word0 >> 16);
2581     const uint16_t counterSetRecordCount = arm::pipe::numeric_cast<uint16_t>(bodyHeader1Word2 >> 16);
2582     const uint16_t categoryRecordCount   = arm::pipe::numeric_cast<uint16_t>(bodyHeader1Word4 >> 16);
2583     CHECK(deviceRecordCount == 1);                      // device_records_count
2584     CHECK(bodyHeader1Word1 == 0 + bodyHeaderSizeBytes);      // device_records_pointer_table_offset
2585     CHECK(counterSetRecordCount == 1);                  // counter_set_count
2586     CHECK(bodyHeader1Word3 == 4 + bodyHeaderSizeBytes);      // counter_set_pointer_table_offset
2587     CHECK(categoryRecordCount == 1);                    // categories_count
2588     CHECK(bodyHeader1Word5 == 8 + bodyHeaderSizeBytes);      // categories_pointer_table_offset
2589 
2590     const uint32_t deviceRecordOffset = ReadUint32(readBuffer1, 32);
2591     CHECK(deviceRecordOffset == 12);
2592 
2593     const uint32_t counterSetRecordOffset = ReadUint32(readBuffer1, 36);
2594     CHECK(counterSetRecordOffset == 28);
2595 
2596     const uint32_t categoryRecordOffset = ReadUint32(readBuffer1, 40);
2597     CHECK(categoryRecordOffset == 48);
2598 
2599     auto readBuffer2 = mockBuffer2.GetReadableBuffer();
2600 
2601     const uint32_t header2Word0 = ReadUint32(readBuffer2, 0);
2602     const uint32_t header2Word1 = ReadUint32(readBuffer2, 4);
2603 
2604     // Timeline message directory packet
2605     CHECK(((header2Word0 >> 26) & 0x0000003F) == 1); // packet family
2606     CHECK(((header2Word0 >> 16) & 0x000003FF) == 0); // packet id
2607     CHECK(header2Word1 == 443);                      // data length
2608 }
2609 
2610 TEST_CASE("CheckProfilingServiceGoodConnectionAcknowledgedPacket")
2611 {
2612     unsigned int streamMetadataPacketsize = GetStreamMetaDataPacketSize();
2613 
2614     // Reset the profiling service to the uninitialized state
2615     ProfilingOptions options;
2616     options.m_EnableProfiling          = true;
2617     armnn::ArmNNProfilingServiceInitialiser initialiser;
2618     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
2619                                       initialiser,
2620                                       arm::pipe::ARMNN_SOFTWARE_INFO,
2621                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
2622                                       arm::pipe::ARMNN_HARDWARE_VERSION);
2623     profilingService.ResetExternalProfilingOptions(options, true);
2624 
2625     // Swap the profiling connection factory in the profiling service instance with our mock one
2626     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
2627 
2628     // Bring the profiling service to the "WaitingForAck" state
2629     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2630     profilingService.Update();    // Initialize the counter directory
2631     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2632     profilingService.Update();    // Create the profiling connection
2633 
2634     // Get the mock profiling connection
2635     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2636     CHECK(mockProfilingConnection);
2637 
2638     // Remove the packets received so far
2639     mockProfilingConnection->Clear();
2640 
2641     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2642     profilingService.Update();    // Start the command handler and the send thread
2643 
2644     // Wait for the Stream Metadata packet to be sent
2645     CHECK(helper.WaitForPacketsSent(
2646             mockProfilingConnection, PacketType::StreamMetaData, streamMetadataPacketsize) >= 1);
2647 
2648     // Write a valid "Connection Acknowledged" packet into the mock profiling connection, to simulate a valid
2649     // reply from an external profiling service
2650 
2651     // Connection Acknowledged Packet header (word 0, word 1 is always zero):
2652     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2653     // 16:25 [10] packet_id: Packet identifier, value 0b0000000001
2654     // 8:15  [8]  reserved: Reserved, value 0b00000000
2655     // 0:7   [8]  reserved: Reserved, value 0b00000000
2656     uint32_t packetFamily = 0;
2657     uint32_t packetId     = 1;
2658     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2659 
2660     // Create the Connection Acknowledged Packet
2661     arm::pipe::Packet connectionAcknowledgedPacket(header);
2662 
2663     // Write the packet to the mock profiling connection
2664     mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket));
2665 
2666     // Wait for the counter directory packet to ensure the ConnectionAcknowledgedCommandHandler has run.
2667     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::CounterDirectory) == 1);
2668 
2669     // The Connection Acknowledged Command Handler should have updated the profiling state accordingly
2670     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2671 
2672     // Reset the profiling service to stop any running thread
2673     options.m_EnableProfiling = false;
2674     profilingService.ResetExternalProfilingOptions(options, true);
2675 }
2676 
2677 TEST_CASE("CheckProfilingServiceGoodRequestCounterDirectoryPacket")
2678 {
2679     // Reset the profiling service to the uninitialized state
2680     ProfilingOptions options;
2681     options.m_EnableProfiling          = true;
2682     armnn::ArmNNProfilingServiceInitialiser initialiser;
2683     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
2684                                       initialiser,
2685                                       arm::pipe::ARMNN_SOFTWARE_INFO,
2686                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
2687                                       arm::pipe::ARMNN_HARDWARE_VERSION);
2688     profilingService.ResetExternalProfilingOptions(options, true);
2689 
2690     // Swap the profiling connection factory in the profiling service instance with our mock one
2691     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
2692 
2693     // Bring the profiling service to the "Active" state
2694     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2695     profilingService.Update();    // Initialize the counter directory
2696     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2697     profilingService.Update();    // Create the profiling connection
2698     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2699     profilingService.Update();    // Start the command handler and the send thread
2700 
2701     // Get the mock profiling connection
2702     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2703     CHECK(mockProfilingConnection);
2704 
2705     // Force the profiling service to the "Active" state
2706     helper.ForceTransitionToState(ProfilingState::Active);
2707     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2708 
2709     // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
2710     // reply from an external profiling service
2711 
2712     // Request Counter Directory packet header (word 0, word 1 is always zero):
2713     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2714     // 16:25 [10] packet_id: Packet identifier, value 0b0000000011
2715     // 8:15  [8]  reserved: Reserved, value 0b00000000
2716     // 0:7   [8]  reserved: Reserved, value 0b00000000
2717     uint32_t packetFamily = 0;
2718     uint32_t packetId     = 3;
2719     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2720 
2721     // Create the Request Counter Directory packet
2722     arm::pipe::Packet requestCounterDirectoryPacket(header);
2723 
2724     // Write the packet to the mock profiling connection
2725     mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
2726 
2727     // Expecting one CounterDirectory Packet of length 652
2728     // and one TimelineMessageDirectory packet of length 451
2729     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::CounterDirectory, 652) == 1);
2730     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::TimelineMessageDirectory, 451) == 1);
2731 
2732     // The Request Counter Directory Command Handler should not have updated the profiling state
2733     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2734 
2735     // Reset the profiling service to stop any running thread
2736     options.m_EnableProfiling = false;
2737     profilingService.ResetExternalProfilingOptions(options, true);
2738 }
2739 
2740 TEST_CASE("CheckProfilingServiceBadPeriodicCounterSelectionPacketInvalidCounterUid")
2741 {
2742     // Reset the profiling service to the uninitialized state
2743     ProfilingOptions options;
2744     options.m_EnableProfiling          = true;
2745     armnn::ArmNNProfilingServiceInitialiser initialiser;
2746     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
2747                                       initialiser,
2748                                       arm::pipe::ARMNN_SOFTWARE_INFO,
2749                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
2750                                       arm::pipe::ARMNN_HARDWARE_VERSION);
2751     profilingService.ResetExternalProfilingOptions(options, true);
2752 
2753     // Swap the profiling connection factory in the profiling service instance with our mock one
2754     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
2755 
2756     // Bring the profiling service to the "Active" state
2757     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2758     profilingService.Update();    // Initialize the counter directory
2759     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2760     profilingService.Update();    // Create the profiling connection
2761     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2762     profilingService.Update();    // Start the command handler and the send thread
2763 
2764     // Get the mock profiling connection
2765     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2766     CHECK(mockProfilingConnection);
2767 
2768     // Force the profiling service to the "Active" state
2769     helper.ForceTransitionToState(ProfilingState::Active);
2770     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2771 
2772     // Remove the packets received so far
2773     mockProfilingConnection->Clear();
2774 
2775     // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2776     // external profiling service
2777 
2778     // Periodic Counter Selection packet header:
2779     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2780     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2781     // 8:15  [8]  reserved: Reserved, value 0b00000000
2782     // 0:7   [8]  reserved: Reserved, value 0b00000000
2783     uint32_t packetFamily = 0;
2784     uint32_t packetId     = 4;
2785     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2786 
2787     uint32_t capturePeriod = 123456;    // Some capture period (microseconds)
2788 
2789     // Get the first valid counter UID
2790     const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
2791     const Counters& counters                  = counterDirectory.GetCounters();
2792     CHECK(counters.size() > 1);
2793     uint16_t counterUidA = counters.begin()->first;    // First valid counter UID
2794     uint16_t counterUidB = 9999;                       // Second invalid counter UID
2795 
2796     uint32_t length = 8;
2797 
2798     auto data = std::make_unique<unsigned char[]>(length);
2799     WriteUint32(data.get(), 0, capturePeriod);
2800     WriteUint16(data.get(), 4, counterUidA);
2801     WriteUint16(data.get(), 6, counterUidB);
2802 
2803     // Create the Periodic Counter Selection packet
2804     // Length > 0, this will start the Period Counter Capture thread
2805     arm::pipe::Packet periodicCounterSelectionPacket(header, length, data);
2806 
2807 
2808     // Write the packet to the mock profiling connection
2809     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2810 
2811     // Expecting one Periodic Counter Selection packet of length 14
2812     // and at least one Periodic Counter Capture packet of length 22
2813     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterSelection, 14) == 1);
2814     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterCapture, 22) >= 1);
2815 
2816     // The Periodic Counter Selection Handler should not have updated the profiling state
2817     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2818 
2819     // Reset the profiling service to stop any running thread
2820     options.m_EnableProfiling = false;
2821     profilingService.ResetExternalProfilingOptions(options, true);
2822 }
2823 
2824 TEST_CASE("CheckProfilingServiceGoodPeriodicCounterSelectionPacketNoCounters")
2825 {
2826     // Reset the profiling service to the uninitialized state
2827     ProfilingOptions options;
2828     options.m_EnableProfiling          = true;
2829     armnn::ArmNNProfilingServiceInitialiser initialiser;
2830     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
2831                                       initialiser,
2832                                       arm::pipe::ARMNN_SOFTWARE_INFO,
2833                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
2834                                       arm::pipe::ARMNN_HARDWARE_VERSION);
2835     profilingService.ResetExternalProfilingOptions(options, true);
2836 
2837     // Swap the profiling connection factory in the profiling service instance with our mock one
2838     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
2839 
2840     // Bring the profiling service to the "Active" state
2841     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2842     profilingService.Update();    // Initialize the counter directory
2843     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2844     profilingService.Update();    // Create the profiling connection
2845     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2846     profilingService.Update();    // Start the command handler and the send thread
2847 
2848     // Get the mock profiling connection
2849     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2850     CHECK(mockProfilingConnection);
2851 
2852     // Wait for the Stream Metadata packet the be sent
2853     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2854     helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData);
2855 
2856     // Force the profiling service to the "Active" state
2857     helper.ForceTransitionToState(ProfilingState::Active);
2858     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2859 
2860     // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2861     // external profiling service
2862 
2863     // Periodic Counter Selection packet header:
2864     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2865     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2866     // 8:15  [8]  reserved: Reserved, value 0b00000000
2867     // 0:7   [8]  reserved: Reserved, value 0b00000000
2868     uint32_t packetFamily = 0;
2869     uint32_t packetId     = 4;
2870     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2871 
2872     // Create the Periodic Counter Selection packet
2873     // Length == 0, this will disable the collection of counters
2874     arm::pipe::Packet periodicCounterSelectionPacket(header);
2875 
2876     // Write the packet to the mock profiling connection
2877     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2878 
2879     // Wait for the Periodic Counter Selection packet of length 12 to be sent
2880     // The size of the expected Periodic Counter Selection (echos the sent one)
2881     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterSelection, 12) == 1);
2882 
2883     // The Periodic Counter Selection Handler should not have updated the profiling state
2884     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2885 
2886     // No Periodic Counter packets are expected
2887     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterCapture, 0, 0) == 0);
2888 
2889     // Reset the profiling service to stop any running thread
2890     options.m_EnableProfiling = false;
2891     profilingService.ResetExternalProfilingOptions(options, true);
2892 }
2893 
2894 TEST_CASE("CheckProfilingServiceGoodPeriodicCounterSelectionPacketSingleCounter")
2895 {
2896     // Reset the profiling service to the uninitialized state
2897     ProfilingOptions options;
2898     options.m_EnableProfiling          = true;
2899     armnn::ArmNNProfilingServiceInitialiser initialiser;
2900     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
2901                                       initialiser,
2902                                       arm::pipe::ARMNN_SOFTWARE_INFO,
2903                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
2904                                       arm::pipe::ARMNN_HARDWARE_VERSION);
2905     profilingService.ResetExternalProfilingOptions(options, true);
2906 
2907     // Swap the profiling connection factory in the profiling service instance with our mock one
2908     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
2909 
2910     // Bring the profiling service to the "Active" state
2911     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2912     profilingService.Update();    // Initialize the counter directory
2913     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2914     profilingService.Update();    // Create the profiling connection
2915     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2916     profilingService.Update();    // Start the command handler and the send thread
2917 
2918     // Get the mock profiling connection
2919     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
2920     CHECK(mockProfilingConnection);
2921 
2922     // Wait for the Stream Metadata packet to be sent
2923     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
2924     helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData);
2925 
2926     // Force the profiling service to the "Active" state
2927     helper.ForceTransitionToState(ProfilingState::Active);
2928     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2929 
2930     // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
2931     // external profiling service
2932 
2933     // Periodic Counter Selection packet header:
2934     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
2935     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
2936     // 8:15  [8]  reserved: Reserved, value 0b00000000
2937     // 0:7   [8]  reserved: Reserved, value 0b00000000
2938     uint32_t packetFamily = 0;
2939     uint32_t packetId     = 4;
2940     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
2941 
2942     uint32_t capturePeriod = 123456;    // Some capture period (microseconds)
2943 
2944     // Get the first valid counter UID
2945     const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
2946     const Counters& counters                  = counterDirectory.GetCounters();
2947     CHECK(!counters.empty());
2948     uint16_t counterUid = counters.begin()->first;    // Valid counter UID
2949 
2950     uint32_t length = 6;
2951 
2952     auto data = std::make_unique<unsigned char[]>(length);
2953     WriteUint32(data.get(), 0, capturePeriod);
2954     WriteUint16(data.get(), 4, counterUid);
2955 
2956     // Create the Periodic Counter Selection packet
2957     // Length > 0, this will start the Period Counter Capture thread
2958     arm::pipe::Packet periodicCounterSelectionPacket(header, length, data);
2959 
2960     // Write the packet to the mock profiling connection
2961     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
2962 
2963     // Expecting one Periodic Counter Selection packet of length 14
2964     // and at least one Periodic Counter Capture packet of length 22
2965     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterSelection, 14) == 1);
2966     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterCapture, 22) >= 1);
2967 
2968     // The Periodic Counter Selection Handler should not have updated the profiling state
2969     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
2970 
2971     // Reset the profiling service to stop any running thread
2972     options.m_EnableProfiling = false;
2973     profilingService.ResetExternalProfilingOptions(options, true);
2974 }
2975 
2976 TEST_CASE("CheckProfilingServiceGoodPeriodicCounterSelectionPacketMultipleCounters")
2977 {
2978     // Reset the profiling service to the uninitialized state
2979     ProfilingOptions options;
2980     options.m_EnableProfiling          = true;
2981     armnn::ArmNNProfilingServiceInitialiser initialiser;
2982     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
2983                                       initialiser,
2984                                       arm::pipe::ARMNN_SOFTWARE_INFO,
2985                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
2986                                       arm::pipe::ARMNN_HARDWARE_VERSION);
2987     profilingService.ResetExternalProfilingOptions(options, true);
2988 
2989     // Swap the profiling connection factory in the profiling service instance with our mock one
2990     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
2991 
2992     // Bring the profiling service to the "Active" state
2993     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
2994     profilingService.Update();    // Initialize the counter directory
2995     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
2996     profilingService.Update();    // Create the profiling connection
2997     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
2998     profilingService.Update();    // Start the command handler and the send thread
2999 
3000     // Get the mock profiling connection
3001     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3002     CHECK(mockProfilingConnection);
3003 
3004     // Wait for the Stream Metadata packet the be sent
3005     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
3006     helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData);
3007 
3008     // Force the profiling service to the "Active" state
3009     helper.ForceTransitionToState(ProfilingState::Active);
3010     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
3011 
3012     // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
3013     // external profiling service
3014 
3015     // Periodic Counter Selection packet header:
3016     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
3017     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
3018     // 8:15  [8]  reserved: Reserved, value 0b00000000
3019     // 0:7   [8]  reserved: Reserved, value 0b00000000
3020     uint32_t packetFamily = 0;
3021     uint32_t packetId     = 4;
3022     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
3023 
3024     uint32_t capturePeriod = 123456;    // Some capture period (microseconds)
3025 
3026     // Get the first valid counter UID
3027     const ICounterDirectory& counterDirectory = profilingService.GetCounterDirectory();
3028     const Counters& counters                  = counterDirectory.GetCounters();
3029     CHECK(counters.size() > 1);
3030     uint16_t counterUidA = counters.begin()->first;        // First valid counter UID
3031     uint16_t counterUidB = (counters.begin()++)->first;    // Second valid counter UID
3032 
3033     uint32_t length = 8;
3034 
3035     auto data = std::make_unique<unsigned char[]>(length);
3036     WriteUint32(data.get(), 0, capturePeriod);
3037     WriteUint16(data.get(), 4, counterUidA);
3038     WriteUint16(data.get(), 6, counterUidB);
3039 
3040     // Create the Periodic Counter Selection packet
3041     // Length > 0, this will start the Period Counter Capture thread
3042     arm::pipe::Packet periodicCounterSelectionPacket(header, length, data);
3043 
3044     // Write the packet to the mock profiling connection
3045     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
3046 
3047     // Expecting one PeriodicCounterSelection Packet with a length of 16
3048     // And at least one PeriodicCounterCapture Packet with a length of 28
3049     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterSelection, 16) == 1);
3050     CHECK(helper.WaitForPacketsSent(mockProfilingConnection, PacketType::PeriodicCounterCapture, 28) >= 1);
3051 
3052     // The Periodic Counter Selection Handler should not have updated the profiling state
3053     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
3054 
3055     // Reset the profiling service to stop any running thread
3056     options.m_EnableProfiling = false;
3057     profilingService.ResetExternalProfilingOptions(options, true);
3058 }
3059 
3060 TEST_CASE("CheckProfilingServiceDisconnect")
3061 {
3062     // Reset the profiling service to the uninitialized state
3063     ProfilingOptions options;
3064     options.m_EnableProfiling          = true;
3065     armnn::ArmNNProfilingServiceInitialiser initialiser;
3066     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3067                                       initialiser,
3068                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3069                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3070                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3071     profilingService.ResetExternalProfilingOptions(options, true);
3072 
3073     // Swap the profiling connection factory in the profiling service instance with our mock one
3074     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
3075 
3076     // Try to disconnect the profiling service while in the "Uninitialised" state
3077     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3078     profilingService.Disconnect();
3079     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);    // The state should not change
3080 
3081     // Try to disconnect the profiling service while in the "NotConnected" state
3082     profilingService.Update();    // Initialize the counter directory
3083     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3084     profilingService.Disconnect();
3085     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);    // The state should not change
3086 
3087     // Try to disconnect the profiling service while in the "WaitingForAck" state
3088     profilingService.Update();    // Create the profiling connection
3089     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
3090     profilingService.Disconnect();
3091     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);    // The state should not change
3092 
3093     // Try to disconnect the profiling service while in the "Active" state
3094     profilingService.Update();    // Start the command handler and the send thread
3095 
3096     // Get the mock profiling connection
3097     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3098     CHECK(mockProfilingConnection);
3099 
3100     // Wait for the Stream Metadata packet the be sent
3101     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
3102     helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData);
3103 
3104     // Force the profiling service to the "Active" state
3105     helper.ForceTransitionToState(ProfilingState::Active);
3106     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
3107 
3108     // Check that the profiling connection is open
3109     CHECK(mockProfilingConnection->IsOpen());
3110 
3111     profilingService.Disconnect();
3112     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);   // The state should have changed
3113 
3114     // Check that the profiling connection has been reset
3115     mockProfilingConnection = helper.GetMockProfilingConnection();
3116     CHECK(mockProfilingConnection == nullptr);
3117 
3118     // Reset the profiling service to stop any running thread
3119     options.m_EnableProfiling = false;
3120     profilingService.ResetExternalProfilingOptions(options, true);
3121 }
3122 
3123 TEST_CASE("CheckProfilingServiceGoodPerJobCounterSelectionPacket")
3124 {
3125     // Reset the profiling service to the uninitialized state
3126     ProfilingOptions options;
3127     options.m_EnableProfiling          = true;
3128     armnn::ArmNNProfilingServiceInitialiser initialiser;
3129     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3130                                       initialiser,
3131                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3132                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3133                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3134     profilingService.ResetExternalProfilingOptions(options, true);
3135 
3136     // Swap the profiling connection factory in the profiling service instance with our mock one
3137     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
3138 
3139     // Bring the profiling service to the "Active" state
3140     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3141     profilingService.Update();    // Initialize the counter directory
3142     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3143     profilingService.Update();    // Create the profiling connection
3144     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
3145     profilingService.Update();    // Start the command handler and the send thread
3146 
3147     // Get the mock profiling connection
3148     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3149     CHECK(mockProfilingConnection);
3150 
3151     // Wait for the Stream Metadata packet the be sent
3152     // (we are not testing the connection acknowledgement here so it will be ignored by this test)
3153     helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData);
3154 
3155     // Force the profiling service to the "Active" state
3156     helper.ForceTransitionToState(ProfilingState::Active);
3157     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
3158 
3159     // Write a "Per-Job Counter Selection" packet into the mock profiling connection, to simulate an input from an
3160     // external profiling service
3161 
3162     // Per-Job Counter Selection packet header:
3163     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
3164     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
3165     // 8:15  [8]  reserved: Reserved, value 0b00000000
3166     // 0:7   [8]  reserved: Reserved, value 0b00000000
3167     uint32_t packetFamily = 0;
3168     uint32_t packetId     = 5;
3169     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
3170 
3171     // Create the Per-Job Counter Selection packet
3172     // Length == 0, this will disable the collection of counters
3173     arm::pipe::Packet periodicCounterSelectionPacket(header);
3174 
3175     // Write the packet to the mock profiling connection
3176     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
3177 
3178     // Wait for a bit (must at least be the delay value of the mock profiling connection) to make sure that
3179     // the Per-Job Counter Selection packet gets processed by the profiling service
3180     std::this_thread::sleep_for(std::chrono::milliseconds(5));
3181 
3182     // The Per-Job Counter Selection Command Handler should not have updated the profiling state
3183     CHECK(profilingService.GetCurrentState() == ProfilingState::Active);
3184 
3185     // The Per-Job Counter Selection packets are dropped silently, so there should be no reply coming
3186     // from the profiling service
3187     const auto StreamMetaDataSize = static_cast<unsigned long>(
3188             helper.WaitForPacketsSent(mockProfilingConnection, PacketType::StreamMetaData, 0, 0));
3189     CHECK(StreamMetaDataSize == mockProfilingConnection->GetWrittenDataSize());
3190 
3191     // Reset the profiling service to stop any running thread
3192     options.m_EnableProfiling = false;
3193     profilingService.ResetExternalProfilingOptions(options, true);
3194 }
3195 
3196 TEST_CASE("CheckConfigureProfilingServiceOn")
3197 {
3198     ProfilingOptions options;
3199     options.m_EnableProfiling          = true;
3200     armnn::ArmNNProfilingServiceInitialiser initialiser;
3201     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3202                                       initialiser,
3203                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3204                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3205                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3206     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3207     profilingService.ConfigureProfilingService(options);
3208     // should get as far as NOT_CONNECTED
3209     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3210     // Reset the profiling service to stop any running thread
3211     options.m_EnableProfiling = false;
3212     profilingService.ResetExternalProfilingOptions(options, true);
3213 }
3214 
3215 TEST_CASE("CheckConfigureProfilingServiceOff")
3216 {
3217     ProfilingOptions options;
3218     armnn::ArmNNProfilingServiceInitialiser initialiser;
3219     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3220                                       initialiser,
3221                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3222                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3223                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3224     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3225     profilingService.ConfigureProfilingService(options);
3226     // should not move from Uninitialised
3227     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3228     // Reset the profiling service to stop any running thread
3229     options.m_EnableProfiling = false;
3230     profilingService.ResetExternalProfilingOptions(options, true);
3231 }
3232 
3233 TEST_CASE("CheckProfilingServiceEnabled")
3234 {
3235     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3236     LogLevelSwapper logLevelSwapper(arm::pipe::LogSeverity::Warning);
3237 
3238     // Redirect the output to a local stream so that we can parse the warning message
3239     std::stringstream ss;
3240     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3241 
3242     ProfilingOptions options;
3243     options.m_EnableProfiling          = true;
3244     armnn::ArmNNProfilingServiceInitialiser initialiser;
3245     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3246                                       initialiser,
3247                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3248                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3249                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3250     profilingService.ResetExternalProfilingOptions(options, true);
3251     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3252     profilingService.Update();
3253     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3254 
3255     profilingService.Update();
3256 
3257     // Reset the profiling service to stop any running thread
3258     options.m_EnableProfiling = false;
3259     profilingService.ResetExternalProfilingOptions(options, true);
3260 
3261     streamRedirector.CancelRedirect();
3262 
3263     // Check that the expected error has occurred and logged to the standard output
3264     if (ss.str().find("Cannot connect to stream socket: Connection refused") == std::string::npos)
3265     {
3266         std::cout << ss.str();
3267         FAIL("Expected string not found.");
3268     }
3269 }
3270 
3271 TEST_CASE("CheckProfilingServiceEnabledRuntime")
3272 {
3273     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3274     LogLevelSwapper logLevelSwapper(arm::pipe::LogSeverity::Warning);
3275 
3276     // Redirect the output to a local stream so that we can parse the warning message
3277     std::stringstream ss;
3278     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3279 
3280     ProfilingOptions options;
3281     armnn::ArmNNProfilingServiceInitialiser initialiser;
3282     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3283                                       initialiser,
3284                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3285                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3286                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3287     profilingService.ResetExternalProfilingOptions(options, true);
3288     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3289     profilingService.Update();
3290     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3291     options.m_EnableProfiling = true;
3292     profilingService.ResetExternalProfilingOptions(options);
3293     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3294     profilingService.Update();
3295     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3296 
3297     profilingService.Update();
3298 
3299     // Reset the profiling service to stop any running thread
3300     options.m_EnableProfiling = false;
3301     profilingService.ResetExternalProfilingOptions(options, true);
3302 
3303     streamRedirector.CancelRedirect();
3304 
3305     // Check that the expected error has occurred and logged to the standard output
3306     if (ss.str().find("Cannot connect to stream socket: Connection refused") == std::string::npos)
3307     {
3308         std::cout << ss.str();
3309         FAIL("Expected string not found.");
3310     }
3311 }
3312 
3313 TEST_CASE("CheckProfilingServiceBadConnectionAcknowledgedPacket")
3314 {
3315     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3316     LogLevelSwapper logLevelSwapper(arm::pipe::LogSeverity::Warning);
3317 
3318     // Redirect the standard output to a local stream so that we can parse the warning message
3319     std::stringstream ss;
3320     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3321 
3322     // Reset the profiling service to the uninitialized state
3323     ProfilingOptions options;
3324     options.m_EnableProfiling          = true;
3325     armnn::ArmNNProfilingServiceInitialiser initialiser;
3326     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3327                                       initialiser,
3328                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3329                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3330                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3331     profilingService.ResetExternalProfilingOptions(options, true);
3332 
3333     // Swap the profiling connection factory in the profiling service instance with our mock one
3334     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
3335 
3336     // Bring the profiling service to the "WaitingForAck" state
3337     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3338     profilingService.Update();    // Initialize the counter directory
3339     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3340     profilingService.Update();    // Create the profiling connection
3341 
3342     // Get the mock profiling connection
3343     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3344     CHECK(mockProfilingConnection);
3345 
3346     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
3347 
3348     // Connection Acknowledged Packet header (word 0, word 1 is always zero):
3349     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
3350     // 16:25 [10] packet_id: Packet identifier, value 0b0000000001
3351     // 8:15  [8]  reserved: Reserved, value 0b00000000
3352     // 0:7   [8]  reserved: Reserved, value 0b00000000
3353     uint32_t packetFamily = 0;
3354     uint32_t packetId     = 37;    // Wrong packet id!!!
3355     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
3356 
3357     // Create the Connection Acknowledged Packet
3358     arm::pipe::Packet connectionAcknowledgedPacket(header);
3359     // Write an invalid "Connection Acknowledged" packet into the mock profiling connection, to simulate an invalid
3360     // reply from an external profiling service
3361     mockProfilingConnection->WritePacket(std::move(connectionAcknowledgedPacket));
3362 
3363     // Start the command thread
3364     profilingService.Update();
3365 
3366     // Wait for the command thread to join
3367     options.m_EnableProfiling = false;
3368     profilingService.ResetExternalProfilingOptions(options, true);
3369 
3370     streamRedirector.CancelRedirect();
3371 
3372     // Check that the expected error has occurred and logged to the standard output
3373     if (ss.str().find("Functor with requested PacketId=37 and Version=4194304 does not exist") == std::string::npos)
3374     {
3375         std::cout << ss.str();
3376         FAIL("Expected string not found.");
3377     }
3378 }
3379 
3380 TEST_CASE("CheckProfilingServiceBadRequestCounterDirectoryPacket")
3381 {
3382     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3383     LogLevelSwapper logLevelSwapper(arm::pipe::LogSeverity::Warning);
3384 
3385     // Redirect the standard output to a local stream so that we can parse the warning message
3386     std::stringstream ss;
3387     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3388 
3389     // Reset the profiling service to the uninitialized state
3390     ProfilingOptions options;
3391     options.m_EnableProfiling          = true;
3392     armnn::ArmNNProfilingServiceInitialiser initialiser;
3393     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3394                                       initialiser,
3395                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3396                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3397                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3398     profilingService.ResetExternalProfilingOptions(options, true);
3399 
3400     // Swap the profiling connection factory in the profiling service instance with our mock one
3401     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
3402 
3403     // Bring the profiling service to the "Active" state
3404     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3405     helper.ForceTransitionToState(ProfilingState::NotConnected);
3406     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3407     profilingService.Update();    // Create the profiling connection
3408     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
3409 
3410     // Get the mock profiling connection
3411     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3412     CHECK(mockProfilingConnection);
3413 
3414     // Write a valid "Request Counter Directory" packet into the mock profiling connection, to simulate a valid
3415     // reply from an external profiling service
3416 
3417     // Request Counter Directory packet header (word 0, word 1 is always zero):
3418     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
3419     // 16:25 [10] packet_id: Packet identifier, value 0b0000000011
3420     // 8:15  [8]  reserved: Reserved, value 0b00000000
3421     // 0:7   [8]  reserved: Reserved, value 0b00000000
3422     uint32_t packetFamily = 0;
3423     uint32_t packetId     = 123;    // Wrong packet id!!!
3424     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
3425 
3426     // Create the Request Counter Directory packet
3427     arm::pipe::Packet requestCounterDirectoryPacket(header);
3428 
3429     // Write the packet to the mock profiling connection
3430     mockProfilingConnection->WritePacket(std::move(requestCounterDirectoryPacket));
3431 
3432     // Start the command handler and the send thread
3433     profilingService.Update();
3434 
3435     // Reset the profiling service to stop and join any running thread
3436     options.m_EnableProfiling = false;
3437     profilingService.ResetExternalProfilingOptions(options, true);
3438 
3439     streamRedirector.CancelRedirect();
3440 
3441     // Check that the expected error has occurred and logged to the standard output
3442     if (ss.str().find("Functor with requested PacketId=123 and Version=4194304 does not exist") == std::string::npos)
3443     {
3444         std::cout << ss.str();
3445         FAIL("Expected string not found.");
3446     }
3447 }
3448 
3449 TEST_CASE("CheckProfilingServiceBadPeriodicCounterSelectionPacket")
3450 {
3451     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3452     LogLevelSwapper logLevelSwapper(arm::pipe::LogSeverity::Warning);
3453 
3454     // Redirect the standard output to a local stream so that we can parse the warning message
3455     std::stringstream ss;
3456     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3457 
3458     // Reset the profiling service to the uninitialized state
3459     ProfilingOptions options;
3460     options.m_EnableProfiling          = true;
3461     armnn::ArmNNProfilingServiceInitialiser initialiser;
3462     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3463                                       initialiser,
3464                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3465                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3466                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3467     profilingService.ResetExternalProfilingOptions(options, true);
3468 
3469     // Swap the profiling connection factory in the profiling service instance with our mock one
3470     SwapProfilingConnectionFactoryHelper helper(arm::pipe::MAX_ARMNN_COUNTER, initialiser, profilingService);
3471 
3472     // Bring the profiling service to the "Active" state
3473     CHECK(profilingService.GetCurrentState() == ProfilingState::Uninitialised);
3474     profilingService.Update();    // Initialize the counter directory
3475     CHECK(profilingService.GetCurrentState() == ProfilingState::NotConnected);
3476     profilingService.Update();    // Create the profiling connection
3477     CHECK(profilingService.GetCurrentState() == ProfilingState::WaitingForAck);
3478     profilingService.Update();    // Start the command handler and the send thread
3479 
3480     // Get the mock profiling connection
3481     MockProfilingConnection* mockProfilingConnection = helper.GetMockProfilingConnection();
3482     CHECK(mockProfilingConnection);
3483 
3484     // Write a "Periodic Counter Selection" packet into the mock profiling connection, to simulate an input from an
3485     // external profiling service
3486 
3487     // Periodic Counter Selection packet header:
3488     // 26:31 [6]  packet_family: Control Packet Family, value 0b000000
3489     // 16:25 [10] packet_id: Packet identifier, value 0b0000000100
3490     // 8:15  [8]  reserved: Reserved, value 0b00000000
3491     // 0:7   [8]  reserved: Reserved, value 0b00000000
3492     uint32_t packetFamily = 0;
3493     uint32_t packetId     = 999;    // Wrong packet id!!!
3494     uint32_t header       = ((packetFamily & 0x0000003F) << 26) | ((packetId & 0x000003FF) << 16);
3495 
3496     // Create the Periodic Counter Selection packet
3497     // Length == 0, this will disable the collection of counters
3498     arm::pipe::Packet periodicCounterSelectionPacket(header);
3499 
3500     // Write the packet to the mock profiling connection
3501     mockProfilingConnection->WritePacket(std::move(periodicCounterSelectionPacket));
3502     profilingService.Update();
3503 
3504     // Reset the profiling service to stop any running thread
3505     options.m_EnableProfiling = false;
3506     profilingService.ResetExternalProfilingOptions(options, true);
3507 
3508     // Check that the expected error has occurred and logged to the standard output
3509     streamRedirector.CancelRedirect();
3510 
3511     // Check that the expected error has occurred and logged to the standard output
3512     if (ss.str().find("Functor with requested PacketId=999 and Version=4194304 does not exist") == std::string::npos)
3513     {
3514         std::cout << ss.str();
3515         FAIL("Expected string not found.");
3516     }
3517 }
3518 
3519 TEST_CASE("CheckCounterIdMap")
3520 {
3521     CounterIdMap counterIdMap;
3522     CHECK_THROWS_AS(counterIdMap.GetBackendId(0), arm::pipe::ProfilingException);
3523     CHECK_THROWS_AS(counterIdMap.GetGlobalId(0, armnn::profiling::BACKEND_ID), arm::pipe::ProfilingException);
3524 
3525     uint16_t globalCounterIds = 0;
3526 
3527     std::string cpuRefId(GetComputeDeviceAsCString(armnn::Compute::CpuRef));
3528     std::string cpuAccId(GetComputeDeviceAsCString(armnn::Compute::CpuAcc));
3529 
3530     std::vector<uint16_t> cpuRefCounters = {0, 1, 2, 3};
3531     std::vector<uint16_t> cpuAccCounters = {0, 1};
3532 
3533     for (uint16_t backendCounterId : cpuRefCounters)
3534     {
3535         counterIdMap.RegisterMapping(globalCounterIds, backendCounterId, cpuRefId);
3536         ++globalCounterIds;
3537     }
3538     for (uint16_t backendCounterId : cpuAccCounters)
3539     {
3540         counterIdMap.RegisterMapping(globalCounterIds, backendCounterId, cpuAccId);
3541         ++globalCounterIds;
3542     }
3543 
3544     CHECK(counterIdMap.GetBackendId(0) == (std::pair<uint16_t, std::string>(0, cpuRefId)));
3545     CHECK(counterIdMap.GetBackendId(1) == (std::pair<uint16_t, std::string>(1, cpuRefId)));
3546     CHECK(counterIdMap.GetBackendId(2) == (std::pair<uint16_t, std::string>(2, cpuRefId)));
3547     CHECK(counterIdMap.GetBackendId(3) == (std::pair<uint16_t, std::string>(3, cpuRefId)));
3548     CHECK(counterIdMap.GetBackendId(4) == (std::pair<uint16_t, std::string>(0, cpuAccId)));
3549     CHECK(counterIdMap.GetBackendId(5) == (std::pair<uint16_t, std::string>(1, cpuAccId)));
3550 
3551     CHECK(counterIdMap.GetGlobalId(0, cpuRefId) == 0);
3552     CHECK(counterIdMap.GetGlobalId(1, cpuRefId) == 1);
3553     CHECK(counterIdMap.GetGlobalId(2, cpuRefId) == 2);
3554     CHECK(counterIdMap.GetGlobalId(3, cpuRefId) == 3);
3555     CHECK(counterIdMap.GetGlobalId(0, cpuAccId) == 4);
3556     CHECK(counterIdMap.GetGlobalId(1, cpuAccId) == 5);
3557 }
3558 
3559 TEST_CASE("CheckRegisterBackendCounters")
3560 {
3561     uint16_t globalCounterIds = INFERENCES_RUN;
3562     std::string cpuRefId(GetComputeDeviceAsCString(armnn::Compute::CpuRef));
3563 
3564     // Reset the profiling service to the uninitialized state
3565     ProfilingOptions options;
3566     options.m_EnableProfiling          = true;
3567     armnn::ArmNNProfilingServiceInitialiser initialiser;
3568     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3569                                       initialiser,
3570                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3571                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3572                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3573     profilingService.ResetExternalProfilingOptions(options, true);
3574 
3575     RegisterBackendCounters registerBackendCounters(globalCounterIds, cpuRefId, profilingService);
3576 
3577 
3578 
3579     CHECK(profilingService.GetCounterDirectory().GetCategories().empty());
3580     registerBackendCounters.RegisterCategory("categoryOne");
3581     auto categoryOnePtr = profilingService.GetCounterDirectory().GetCategory("categoryOne");
3582     CHECK(categoryOnePtr);
3583 
3584     CHECK(profilingService.GetCounterDirectory().GetDevices().empty());
3585     globalCounterIds = registerBackendCounters.RegisterDevice("deviceOne");
3586     auto deviceOnePtr = profilingService.GetCounterDirectory().GetDevice(globalCounterIds);
3587     CHECK(deviceOnePtr);
3588     CHECK(deviceOnePtr->m_Name == "deviceOne");
3589 
3590     CHECK(profilingService.GetCounterDirectory().GetCounterSets().empty());
3591     globalCounterIds = registerBackendCounters.RegisterCounterSet("counterSetOne");
3592     auto counterSetOnePtr = profilingService.GetCounterDirectory().GetCounterSet(globalCounterIds);
3593     CHECK(counterSetOnePtr);
3594     CHECK(counterSetOnePtr->m_Name == "counterSetOne");
3595 
3596     uint16_t newGlobalCounterId = registerBackendCounters.RegisterCounter(0,
3597                                                                           "categoryOne",
3598                                                                           0,
3599                                                                           0,
3600                                                                           1.f,
3601                                                                           "CounterOne",
3602                                                                           "first test counter");
3603     CHECK((newGlobalCounterId = INFERENCES_RUN + 1));
3604     uint16_t mappedGlobalId = profilingService.GetCounterMappings().GetGlobalId(0, cpuRefId);
3605     CHECK(mappedGlobalId == newGlobalCounterId);
3606     auto backendMapping = profilingService.GetCounterMappings().GetBackendId(newGlobalCounterId);
3607     CHECK(backendMapping.first == 0);
3608     CHECK(backendMapping.second == cpuRefId);
3609 
3610     // Reset the profiling service to stop any running thread
3611     options.m_EnableProfiling = false;
3612     profilingService.ResetExternalProfilingOptions(options, true);
3613 }
3614 
3615 TEST_CASE("CheckCounterStatusQuery")
3616 {
3617     ProfilingOptions options;
3618     options.m_EnableProfiling = true;
3619 
3620     // Reset the profiling service to the uninitialized state
3621     armnn::ArmNNProfilingServiceInitialiser initialiser;
3622     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3623                                       initialiser,
3624                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3625                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3626                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3627     profilingService.ResetExternalProfilingOptions(options, true);
3628 
3629     const std::string cpuRefId(GetComputeDeviceAsCString(armnn::Compute::CpuRef));
3630     const std::string cpuAccId(GetComputeDeviceAsCString(armnn::Compute::CpuAcc));
3631 
3632     // Create BackendProfiling for each backend
3633     BackendProfiling backendProfilingCpuRef(options, profilingService, cpuRefId);
3634     BackendProfiling backendProfilingCpuAcc(options, profilingService, cpuAccId);
3635 
3636     uint16_t initialNumGlobalCounterIds = INFERENCES_RUN;
3637 
3638     // Create RegisterBackendCounters for CpuRef
3639     RegisterBackendCounters registerBackendCountersCpuRef(initialNumGlobalCounterIds, cpuRefId, profilingService);
3640 
3641     // Create 'testCategory' in CounterDirectory (backend agnostic)
3642     CHECK(profilingService.GetCounterDirectory().GetCategories().empty());
3643     registerBackendCountersCpuRef.RegisterCategory("testCategory");
3644     auto categoryOnePtr = profilingService.GetCounterDirectory().GetCategory("testCategory");
3645     CHECK(categoryOnePtr);
3646 
3647     // Counters:
3648     // Global | Local | Backend
3649     //    5   |   0   | CpuRef
3650     //    6   |   1   | CpuRef
3651     //    7   |   1   | CpuAcc
3652 
3653     std::vector<uint16_t> cpuRefCounters = {0, 1};
3654     std::vector<uint16_t> cpuAccCounters = {0};
3655 
3656     // Register the backend counters for CpuRef and validate GetGlobalId and GetBackendId
3657     uint16_t currentNumGlobalCounterIds = registerBackendCountersCpuRef.RegisterCounter(
3658             0, "testCategory", 0, 0, 1.f, "CpuRefCounter0", "Zeroth CpuRef Counter");
3659     CHECK(currentNumGlobalCounterIds == initialNumGlobalCounterIds + 1);
3660     uint16_t mappedGlobalId = profilingService.GetCounterMappings().GetGlobalId(0, cpuRefId);
3661     CHECK(mappedGlobalId == currentNumGlobalCounterIds);
3662     auto backendMapping = profilingService.GetCounterMappings().GetBackendId(currentNumGlobalCounterIds);
3663     CHECK(backendMapping.first == 0);
3664     CHECK(backendMapping.second == cpuRefId);
3665 
3666     currentNumGlobalCounterIds = registerBackendCountersCpuRef.RegisterCounter(
3667             1, "testCategory", 0, 0, 1.f, "CpuRefCounter1", "First CpuRef Counter");
3668     CHECK(currentNumGlobalCounterIds == initialNumGlobalCounterIds + 2);
3669     mappedGlobalId = profilingService.GetCounterMappings().GetGlobalId(1, cpuRefId);
3670     CHECK(mappedGlobalId == currentNumGlobalCounterIds);
3671     backendMapping = profilingService.GetCounterMappings().GetBackendId(currentNumGlobalCounterIds);
3672     CHECK(backendMapping.first == 1);
3673     CHECK(backendMapping.second == cpuRefId);
3674 
3675     // Create RegisterBackendCounters for CpuAcc
3676     RegisterBackendCounters registerBackendCountersCpuAcc(currentNumGlobalCounterIds, cpuAccId, profilingService);
3677 
3678     // Register the backend counter for CpuAcc and validate GetGlobalId and GetBackendId
3679     currentNumGlobalCounterIds = registerBackendCountersCpuAcc.RegisterCounter(
3680             0, "testCategory", 0, 0, 1.f, "CpuAccCounter0", "Zeroth CpuAcc Counter");
3681     CHECK(currentNumGlobalCounterIds == initialNumGlobalCounterIds + 3);
3682     mappedGlobalId = profilingService.GetCounterMappings().GetGlobalId(0, cpuAccId);
3683     CHECK(mappedGlobalId == currentNumGlobalCounterIds);
3684     backendMapping = profilingService.GetCounterMappings().GetBackendId(currentNumGlobalCounterIds);
3685     CHECK(backendMapping.first == 0);
3686     CHECK(backendMapping.second == cpuAccId);
3687 
3688     // Create vectors for active counters
3689     const std::vector<uint16_t> activeGlobalCounterIds = {5}; // CpuRef(0) activated
3690     const std::vector<uint16_t> newActiveGlobalCounterIds = {6, 7}; // CpuRef(0) and CpuAcc(1) activated
3691 
3692     const uint32_t capturePeriod = 200;
3693     const uint32_t newCapturePeriod = 100;
3694 
3695     // Set capture period and active counters in CaptureData
3696     profilingService.SetCaptureData(capturePeriod, activeGlobalCounterIds, {});
3697 
3698     // Get vector of active counters for CpuRef and CpuAcc backends
3699     std::vector<CounterStatus> cpuRefCounterStatus = backendProfilingCpuRef.GetActiveCounters();
3700     std::vector<CounterStatus> cpuAccCounterStatus = backendProfilingCpuAcc.GetActiveCounters();
3701     CHECK_EQ(cpuRefCounterStatus.size(), 1);
3702     CHECK_EQ(cpuAccCounterStatus.size(), 0);
3703 
3704     // Check active CpuRef counter
3705     CHECK_EQ(cpuRefCounterStatus[0].m_GlobalCounterId, activeGlobalCounterIds[0]);
3706     CHECK_EQ(cpuRefCounterStatus[0].m_BackendCounterId, cpuRefCounters[0]);
3707     CHECK_EQ(cpuRefCounterStatus[0].m_SamplingRateInMicroseconds, capturePeriod);
3708     CHECK_EQ(cpuRefCounterStatus[0].m_Enabled, true);
3709 
3710     // Check inactive CpuRef counter
3711     CounterStatus inactiveCpuRefCounter = backendProfilingCpuRef.GetCounterStatus(cpuRefCounters[1]);
3712     CHECK_EQ(inactiveCpuRefCounter.m_GlobalCounterId, 6);
3713     CHECK_EQ(inactiveCpuRefCounter.m_BackendCounterId, cpuRefCounters[1]);
3714     CHECK_EQ(inactiveCpuRefCounter.m_SamplingRateInMicroseconds, 0);
3715     CHECK_EQ(inactiveCpuRefCounter.m_Enabled, false);
3716 
3717     // Check inactive CpuAcc counter
3718     CounterStatus inactiveCpuAccCounter = backendProfilingCpuAcc.GetCounterStatus(cpuAccCounters[0]);
3719     CHECK_EQ(inactiveCpuAccCounter.m_GlobalCounterId, 7);
3720     CHECK_EQ(inactiveCpuAccCounter.m_BackendCounterId, cpuAccCounters[0]);
3721     CHECK_EQ(inactiveCpuAccCounter.m_SamplingRateInMicroseconds, 0);
3722     CHECK_EQ(inactiveCpuAccCounter.m_Enabled, false);
3723 
3724     // Set new capture period and new active counters in CaptureData
3725     profilingService.SetCaptureData(newCapturePeriod, newActiveGlobalCounterIds, {});
3726 
3727     // Get vector of active counters for CpuRef and CpuAcc backends
3728     cpuRefCounterStatus = backendProfilingCpuRef.GetActiveCounters();
3729     cpuAccCounterStatus = backendProfilingCpuAcc.GetActiveCounters();
3730     CHECK_EQ(cpuRefCounterStatus.size(), 1);
3731     CHECK_EQ(cpuAccCounterStatus.size(), 1);
3732 
3733     // Check active CpuRef counter
3734     CHECK_EQ(cpuRefCounterStatus[0].m_GlobalCounterId, newActiveGlobalCounterIds[0]);
3735     CHECK_EQ(cpuRefCounterStatus[0].m_BackendCounterId, cpuRefCounters[1]);
3736     CHECK_EQ(cpuRefCounterStatus[0].m_SamplingRateInMicroseconds, newCapturePeriod);
3737     CHECK_EQ(cpuRefCounterStatus[0].m_Enabled, true);
3738 
3739     // Check active CpuAcc counter
3740     CHECK_EQ(cpuAccCounterStatus[0].m_GlobalCounterId, newActiveGlobalCounterIds[1]);
3741     CHECK_EQ(cpuAccCounterStatus[0].m_BackendCounterId, cpuAccCounters[0]);
3742     CHECK_EQ(cpuAccCounterStatus[0].m_SamplingRateInMicroseconds, newCapturePeriod);
3743     CHECK_EQ(cpuAccCounterStatus[0].m_Enabled, true);
3744 
3745     // Check inactive CpuRef counter
3746     inactiveCpuRefCounter = backendProfilingCpuRef.GetCounterStatus(cpuRefCounters[0]);
3747     CHECK_EQ(inactiveCpuRefCounter.m_GlobalCounterId, 5);
3748     CHECK_EQ(inactiveCpuRefCounter.m_BackendCounterId, cpuRefCounters[0]);
3749     CHECK_EQ(inactiveCpuRefCounter.m_SamplingRateInMicroseconds, 0);
3750     CHECK_EQ(inactiveCpuRefCounter.m_Enabled, false);
3751 
3752     // Reset the profiling service to stop any running thread
3753     options.m_EnableProfiling = false;
3754     profilingService.ResetExternalProfilingOptions(options, true);
3755 }
3756 
3757 TEST_CASE("CheckRegisterCounters")
3758 {
3759     ProfilingOptions options;
3760     options.m_EnableProfiling = true;
3761     MockBufferManager mockBuffer(1024);
3762 
3763     CaptureData captureData;
3764 
3765     armnn::ArmNNProfilingServiceInitialiser initialiser;
3766     MockProfilingService mockProfilingService(
3767         arm::pipe::MAX_ARMNN_COUNTER, initialiser, mockBuffer, options.m_EnableProfiling, captureData);
3768     std::string cpuRefId(GetComputeDeviceAsCString(armnn::Compute::CpuRef));
3769 
3770     mockProfilingService.RegisterMapping(6, 0, cpuRefId);
3771     mockProfilingService.RegisterMapping(7, 1, cpuRefId);
3772     mockProfilingService.RegisterMapping(8, 2, cpuRefId);
3773 
3774     BackendProfiling backendProfiling(options,
3775                                                         mockProfilingService,
3776                                                         cpuRefId);
3777 
3778     Timestamp timestamp;
3779     timestamp.timestamp = 1000998;
3780     timestamp.counterValues.emplace_back(0, 700);
3781     timestamp.counterValues.emplace_back(2, 93);
3782     std::vector<Timestamp> timestamps;
3783     timestamps.push_back(timestamp);
3784     backendProfiling.ReportCounters(timestamps);
3785 
3786     auto readBuffer = mockBuffer.GetReadableBuffer();
3787 
3788     uint32_t headerWord0 = ReadUint32(readBuffer, 0);
3789     uint32_t headerWord1 = ReadUint32(readBuffer, 4);
3790     uint64_t readTimestamp = ReadUint64(readBuffer, 8);
3791 
3792     CHECK(((headerWord0 >> 26) & 0x0000003F) == 3); // packet family
3793     CHECK(((headerWord0 >> 19) & 0x0000007F) == 0); // packet class
3794     CHECK(((headerWord0 >> 16) & 0x00000007) == 0); // packet type
3795     CHECK(headerWord1 == 20);                       // data length
3796     CHECK(1000998 == readTimestamp);                // capture period
3797 
3798     uint32_t offset = 16;
3799     // Check Counter Index
3800     uint16_t readIndex = ReadUint16(readBuffer, offset);
3801     CHECK(6 == readIndex);
3802 
3803     // Check Counter Value
3804     offset += 2;
3805     uint32_t readValue = ReadUint32(readBuffer, offset);
3806     CHECK(700 == readValue);
3807 
3808     // Check Counter Index
3809     offset += 4;
3810     readIndex = ReadUint16(readBuffer, offset);
3811     CHECK(8 == readIndex);
3812 
3813     // Check Counter Value
3814     offset += 2;
3815     readValue = ReadUint32(readBuffer, offset);
3816     CHECK(93 == readValue);
3817 }
3818 
3819 TEST_CASE("CheckFileFormat") {
3820     // Locally reduce log level to "Warning", as this test needs to parse a warning message from the standard output
3821     LogLevelSwapper logLevelSwapper(arm::pipe::LogSeverity::Warning);
3822 
3823     // Redirect the output to a local stream so that we can parse the warning message
3824     std::stringstream ss;
3825     StreamRedirector streamRedirector(std::cout, ss.rdbuf());
3826 
3827     // Create profiling options.
3828     ProfilingOptions options;
3829     options.m_EnableProfiling = true;
3830     // Check the default value set to binary
3831     CHECK(options.m_FileFormat == "binary");
3832 
3833     // Change file format to an unsupported value
3834     options.m_FileFormat = "json";
3835     // Enable the profiling service
3836     armnn::ArmNNProfilingServiceInitialiser initialiser;
3837     ProfilingService profilingService(arm::pipe::MAX_ARMNN_COUNTER,
3838                                       initialiser,
3839                                       arm::pipe::ARMNN_SOFTWARE_INFO,
3840                                       arm::pipe::ARMNN_SOFTWARE_VERSION,
3841                                       arm::pipe::ARMNN_HARDWARE_VERSION);
3842     profilingService.ResetExternalProfilingOptions(options, true);
3843     // Start the command handler and the send thread
3844     profilingService.Update();
3845     CHECK(profilingService.GetCurrentState()==ProfilingState::NotConnected);
3846 
3847     // When Update is called and the current state is ProfilingState::NotConnected
3848     // an exception will be raised from GetProfilingConnection and displayed as warning in the output local stream
3849     profilingService.Update();
3850 
3851     streamRedirector.CancelRedirect();
3852 
3853     // Check that the expected error has occurred and logged to the standard output
3854     if (ss.str().find("Unsupported profiling file format, only binary is supported") == std::string::npos)
3855     {
3856         std::cout << ss.str();
3857         FAIL("Expected string not found.");
3858     }
3859 }
3860 
3861 }
3862