1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include <armnn/utility/TransformIterator.hpp>
7
8 #include <doctest/doctest.h>
9 #include <vector>
10 #include <iostream>
11
12 using namespace armnn;
13
14 TEST_SUITE("TransformIteratorSuite")
15 {
16 namespace
17 {
18
// Transform used by the int test cases: returns the argument multiplied by itself.
static int square(const int val)
{
    const int product = val * val;
    return product;
}
23
// Transform used by the string test cases: returns a copy of the input with "a" appended.
// Takes the input by const reference; the previous by-value const parameter forced a
// needless std::string copy on every call without allowing it to be moved from.
static std::string concat(const std::string& val)
{
    return val + "a";
}
28
29 TEST_CASE("TransformIteratorTest")
30 {
31 struct WrapperTestClass
32 {
begin__anona36a5b8b0111::WrapperTestClass33 TransformIterator<decltype(&square), std::vector<int>::const_iterator> begin() const
34 {
35 return { m_Vec.begin(), &square };
36 }
37
end__anona36a5b8b0111::WrapperTestClass38 TransformIterator<decltype(&square), std::vector<int>::const_iterator> end() const
39 {
40 return { m_Vec.end(), &square };
41 }
42
43 const std::vector<int> m_Vec{1, 2, 3, 4, 5};
44 };
45
46 struct WrapperStringClass
47 {
begin__anona36a5b8b0111::WrapperStringClass48 TransformIterator<decltype(&concat), std::vector<std::string>::const_iterator> begin() const
49 {
50 return { m_Vec.begin(), &concat };
51 }
52
end__anona36a5b8b0111::WrapperStringClass53 TransformIterator<decltype(&concat), std::vector<std::string>::const_iterator> end() const
54 {
55 return { m_Vec.end(), &concat };
56 }
57
58 const std::vector<std::string> m_Vec{"a", "b", "c"};
59 };
60
61 WrapperStringClass wrapperStringClass;
62 WrapperTestClass wrapperTestClass;
63 int i = 1;
64
65 for(auto val : wrapperStringClass)
66 {
67 CHECK(val != "e");
68 i++;
69 }
70
71 i = 1;
72 for(auto val : wrapperTestClass)
73 {
74 CHECK(val == square(i));
75 i++;
76 }
77
78 i = 1;
79 // Check original vector is unchanged
80 for(auto val : wrapperTestClass.m_Vec)
81 {
82 CHECK(val == i);
83 i++;
84 }
85
86 std::vector<int> originalVec{1, 2, 3, 4, 5};
87
88 auto transformBegin = MakeTransformIterator(originalVec.begin(), &square);
89 auto transformEnd = MakeTransformIterator(originalVec.end(), &square);
90
91 std::vector<int> transformedVec(transformBegin, transformEnd);
92
93 i = 1;
94 for(auto val : transformedVec)
95 {
96 CHECK(val == square(i));
97 i++;
98 }
99 }
100
101 }
102
103 }
104