1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/extension/pytree/pytree.h>

#include <gtest/gtest.h>
#include <cstdint>
#include <string>
13
14 using ::executorch::extension::pytree::ContainerHandle;
15 using ::executorch::extension::pytree::Key;
16 using ::executorch::extension::pytree::Kind;
17 using ::executorch::extension::pytree::unflatten;
18
// Leaf payload type stored in every container built by the tests below.
using Leaf = std::int32_t;
20
TEST(PyTreeTest, List) {
  // A two-element list whose children are both leaves, filled from `items`
  // in order.
  std::string spec = "L2#1#1($,$)";
  Leaf items[2] = {11, 12};
  auto handle = unflatten(spec, items);
  ASSERT_TRUE(handle.isList());
  ASSERT_EQ(handle.size(), 2);

  const Leaf expected[2] = {11, 12};
  for (size_t idx = 0; idx < 2; ++idx) {
    const auto& child = handle[idx];
    ASSERT_TRUE(child.isLeaf());
    ASSERT_EQ(child, expected[idx]);
  }
}
36
TEST(PyTreeTest, Tuple) {
  // A single-element tuple holding one leaf.
  Leaf items[1] = {11};
  std::string spec = "T1#1($)";
  auto tuple = unflatten(spec, items);
  ASSERT_TRUE(tuple.isTuple());
  ASSERT_EQ(tuple.size(), 1);

  const auto& only_child = tuple[0];
  ASSERT_TRUE(only_child.isLeaf());
  ASSERT_EQ(only_child, 11);
}
49
TEST(PyTreeTest, Dict) {
  // Two string-keyed entries, each mapping to a leaf.
  Leaf items[2] = {11, 12};
  std::string spec = "D2#1#1('key0':$,'key1':$)";
  auto dict = unflatten(spec, items);
  ASSERT_TRUE(dict.isDict());
  ASSERT_EQ(dict.size(), 2);

  // Keys are reported in the order they appear in the spec.
  ASSERT_TRUE(dict.key(0) == Key("key0"));
  ASSERT_TRUE(dict.key(1) == Key("key1"));

  // Positional access.
  const auto& first = dict[0];
  const auto& second = dict[1];
  ASSERT_TRUE(first.isLeaf());
  ASSERT_TRUE(second.isLeaf());
  ASSERT_EQ(first, 11);
  ASSERT_EQ(second, 12);

  // Lookup by key resolves to the same children.
  ASSERT_EQ(first, dict.at("key0"));
  ASSERT_EQ(dict.at("key0"), 11);
  ASSERT_EQ(dict.at("key1"), 12);
}
76
TEST(PyTreeTest, Leaf) {
  // A bare "$" spec describes a single leaf, so exactly one item is supplied.
  // The array size matches the spec's leaf count, as in the other tests.
  std::string spec = "$";
  Leaf items[1] = {11};
  auto c = unflatten(spec, items);
  ASSERT_TRUE(c.isLeaf());
  ASSERT_EQ(c, 11);
}
84
TEST(PyTreeTest, DictWithList) {
  // 'key0' maps to a two-leaf list and 'key1' to a plain leaf. The spec
  // assigns items 11 and 12 to the list and 13 to 'key1'.
  Leaf items[3] = {11, 12, 13};
  std::string spec = "D2#2#1('key0':L2#1#1($,$),'key1':$)";
  auto dict = unflatten(spec, items);
  ASSERT_TRUE(dict.isDict());
  ASSERT_EQ(dict.size(), 2);

  ASSERT_TRUE(dict.key(0) == Key("key0"));
  ASSERT_TRUE(dict.key(1) == Key("key1"));

  // Second entry: a single leaf.
  const auto& tail_leaf = dict[1];
  ASSERT_TRUE(tail_leaf.isLeaf());
  ASSERT_EQ(tail_leaf, 13);

  // First entry: a nested list of two leaves.
  const auto& nested = dict[0];
  ASSERT_TRUE(nested.isList());
  ASSERT_EQ(nested.size(), 2);

  const Leaf expected[2] = {11, 12};
  for (size_t idx = 0; idx < 2; ++idx) {
    ASSERT_TRUE(nested[idx].isLeaf());
  }
  for (size_t idx = 0; idx < 2; ++idx) {
    ASSERT_EQ(nested[idx], expected[idx]);
  }
}
115
TEST(PyTreeTest, DictKeysStrInt) {
  // A dict mixing one string key with integer keys of one, two and three
  // digits, so integer-key parsing is exercised at several widths.
  std::string spec = "D4#1#1#1#1('key0':$,1:$,23:$,123:$)";
  Leaf items[4] = {11, 12, 13, 14};
  auto c = unflatten(spec, items);
  ASSERT_TRUE(c.isDict());
  ASSERT_EQ(c.size(), 4);

  // Every key, including the multi-digit integer ones, is checked in spec
  // order.
  ASSERT_TRUE(c.key(0) == Key("key0"));
  ASSERT_TRUE(c.key(1) == Key(1));
  ASSERT_TRUE(c.key(2) == Key(23));
  ASSERT_TRUE(c.key(3) == Key(123));

  // Every child is a leaf, filled from `items` in order.
  const Leaf expected[4] = {11, 12, 13, 14};
  for (size_t i = 0; i < 4; ++i) {
    ASSERT_TRUE(c[i].isLeaf());
    ASSERT_EQ(c[i], expected[i]);
  }

  // Lookup by key (string or integer) agrees with positional access.
  const auto& child0 = c[0];
  const auto& ckey0 = c.at("key0");
  ASSERT_EQ(child0, ckey0);

  ASSERT_EQ(c.at(1), 12);
  ASSERT_EQ(c.at(23), 13);
  ASSERT_EQ(c.at(123), 14);
}
143
TEST(pytree, FlattenDict) {
  // Hand-build a flat dict of three leaves, mixing an int key with string keys.
  Leaf items[3] = {11, 12, 13};
  auto dict = ContainerHandle<Leaf>(Kind::Dict, 3);
  for (size_t slot = 0; slot < 3; ++slot) {
    dict[slot] = &items[slot];
  }
  dict.key(0) = 0;
  dict.key(1) = Key("key_1");
  dict.key(2) = Key("key_2");

  // Flattening yields the leaves in order, each pointing at its source item.
  auto flattened = flatten(dict);
  const auto& leaves = flattened.first;
  ASSERT_EQ(leaves.size(), 3);
  for (size_t slot = 0; slot < leaves.size(); ++slot) {
    ASSERT_EQ(*leaves[slot], items[slot]);
  }
}
160
TEST(pytree, FlattenNestedDict) {
  Leaf items[5] = {11, 12, 13, 14, 15};

  // Inner dict: holds the last three items.
  auto inner = ContainerHandle<Leaf>(Kind::Dict, 3);
  for (size_t slot = 0; slot < 3; ++slot) {
    inner[slot] = &items[slot + 2];
  }
  inner.key(0) = Key("c2_key_0");
  inner.key(1) = Key("c2_key_1");
  inner.key(2) = Key("c2_key_2");

  // Outer dict: the first two items as leaves, then the inner dict.
  auto outer = ContainerHandle<Leaf>(Kind::Dict, 3);
  outer[0] = &items[0];
  outer[1] = &items[1];
  outer[2] = std::move(inner);
  outer.key(0) = 0;
  outer.key(1) = Key("key_1");
  outer.key(2) = Key("key_2");

  // Depth-first flattening visits the leaves in the same order as `items`.
  auto flattened = flatten(outer);
  const auto& leaves = flattened.first;
  ASSERT_EQ(leaves.size(), 5);
  for (size_t slot = 0; slot < leaves.size(); ++slot) {
    ASSERT_EQ(*leaves[slot], items[slot]);
  }
}
186