// xref: /aosp_15_r20/external/executorch/extension/pytree/test/test_pytree.cpp
// (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/pytree/pytree.h>

#include <cstddef>
#include <cstdint>
#include <gtest/gtest.h>
#include <string>
#include <utility>

using ::executorch::extension::pytree::ContainerHandle;
using ::executorch::extension::pytree::Key;
using ::executorch::extension::pytree::Kind;
using ::executorch::extension::pytree::unflatten;

using Leaf = int32_t;

TEST(PyTreeTest, List) {
  // A two-element list; both children are leaves fed from leaf_values.
  std::string list_spec = "L2#1#1($,$)";
  Leaf leaf_values[2] = {11, 12};
  auto root = unflatten(list_spec, leaf_values);

  ASSERT_TRUE(root.isList());
  ASSERT_EQ(root.size(), 2);

  // Positional access into the list.
  const auto& first = root[0];
  const auto& second = root[1];

  // Both entries must be leaves carrying the values in order.
  ASSERT_TRUE(first.isLeaf());
  ASSERT_TRUE(second.isLeaf());
  ASSERT_EQ(first, 11);
  ASSERT_EQ(second, 12);
}

TEST(PyTreeTest, Tuple) {
  // A one-element tuple wrapping a single leaf.
  Leaf leaf_values[1] = {11};
  std::string tuple_spec = "T1#1($)";
  auto root = unflatten(tuple_spec, leaf_values);

  ASSERT_TRUE(root.isTuple());
  ASSERT_EQ(root.size(), 1);

  // The sole child is the leaf taken from leaf_values[0].
  const auto& only_child = root[0];
  ASSERT_TRUE(only_child.isLeaf());
  ASSERT_EQ(only_child, 11);
}

TEST(PyTreeTest, Dict) {
  // Two string-keyed entries, each holding one leaf.
  Leaf leaf_values[2] = {11, 12};
  std::string dict_spec = "D2#1#1('key0':$,'key1':$)";
  auto root = unflatten(dict_spec, leaf_values);

  ASSERT_TRUE(root.isDict());
  ASSERT_EQ(root.size(), 2);

  // Keys, read back positionally; expected in the order written in the spec.
  const auto& first_key = root.key(0);
  const auto& second_key = root.key(1);
  ASSERT_TRUE(first_key == Key("key0"));
  ASSERT_TRUE(second_key == Key("key1"));

  // Values, read back positionally.
  const auto& first_value = root[0];
  const auto& second_value = root[1];
  ASSERT_TRUE(first_value.isLeaf());
  ASSERT_TRUE(second_value.isLeaf());
  ASSERT_EQ(first_value, 11);
  ASSERT_EQ(second_value, 12);

  // Lookup by key must agree with positional access.
  const auto& looked_up = root.at("key0");
  ASSERT_EQ(first_value, looked_up);

  ASSERT_EQ(root.at("key0"), 11);
  ASSERT_EQ(root.at("key1"), 12);
}

TEST(PyTreeTest, Leaf) {
  // The spec "$" is a single leaf token with no enclosing container. Like the
  // other tests, where each "$" is matched by one array slot, it needs exactly
  // one value. The array is sized to match, so no unused slot is left behind.
  std::string spec = "$";
  Leaf items[1] = {11};
  auto c = unflatten(spec, items);
  ASSERT_TRUE(c.isLeaf());
  ASSERT_EQ(c, 11);
}

TEST(PyTreeTest, DictWithList) {
  // Dict whose 'key0' entry is a nested two-leaf list and whose 'key1' entry
  // is a plain leaf. The test expects the leaves to be assigned in spec order:
  // the list items get 11 and 12, and 'key1' gets 13.
  Leaf leaf_values[3] = {11, 12, 13};
  std::string nested_spec = "D2#2#1('key0':L2#1#1($,$),'key1':$)";
  auto root = unflatten(nested_spec, leaf_values);

  ASSERT_TRUE(root.isDict());
  ASSERT_EQ(root.size(), 2);

  const auto& first_key = root.key(0);
  const auto& second_key = root.key(1);
  ASSERT_TRUE(first_key == Key("key0"));
  ASSERT_TRUE(second_key == Key("key1"));

  // 'key1' holds the last leaf directly.
  const auto& scalar_entry = root[1];
  ASSERT_TRUE(scalar_entry.isLeaf());
  ASSERT_EQ(scalar_entry, 13);

  // 'key0' holds the nested list with the first two leaves.
  const auto& nested_list = root[0];
  ASSERT_TRUE(nested_list.isList());
  ASSERT_EQ(nested_list.size(), 2);

  const auto& inner_first = nested_list[0];
  const auto& inner_second = nested_list[1];

  ASSERT_TRUE(inner_first.isLeaf());
  ASSERT_TRUE(inner_second.isLeaf());

  ASSERT_EQ(inner_first, 11);
  ASSERT_EQ(inner_second, 12);
}

TEST(PyTreeTest, DictKeysStrInt) {
  // Dict mixing one string key with three integer keys of 1, 2 and 3 digits.
  // The multi-digit keys guard against integer-key parsing that stops early
  // or misorders entries.
  std::string spec = "D4#1#1#1#1('key0':$,1:$,23:$,123:$)";
  Leaf items[4] = {11, 12, 13, 14};
  auto c = unflatten(spec, items);
  ASSERT_TRUE(c.isDict());
  ASSERT_EQ(c.size(), 4);

  // Every key must come back in spec order and with the right kind (str/int).
  const auto& key0 = c.key(0);
  const auto& key1 = c.key(1);
  const auto& key2 = c.key(2);
  const auto& key3 = c.key(3);

  ASSERT_TRUE(key0 == Key("key0"));
  ASSERT_TRUE(key1 == Key(1));
  ASSERT_TRUE(key2 == Key(23));
  ASSERT_TRUE(key3 == Key(123));

  // Every child must be a leaf holding the value at the matching position.
  const auto& child0 = c[0];
  const auto& child1 = c[1];
  const auto& child2 = c[2];
  const auto& child3 = c[3];
  ASSERT_TRUE(child0.isLeaf());
  ASSERT_TRUE(child1.isLeaf());
  ASSERT_TRUE(child2.isLeaf());
  ASSERT_TRUE(child3.isLeaf());
  ASSERT_EQ(child0, 11);
  ASSERT_EQ(child1, 12);
  ASSERT_EQ(child2, 13);
  ASSERT_EQ(child3, 14);

  // Lookup by key (string and integer) must agree with positional access.
  const auto& ckey0 = c.at("key0");
  ASSERT_EQ(child0, ckey0);

  ASSERT_EQ(c.at("key0"), 11);
  ASSERT_EQ(c.at(1), 12);
  ASSERT_EQ(c.at(23), 13);
  ASSERT_EQ(c.at(123), 14);
}

TEST(pytree, FlattenDict) {
  // Hand-build a three-entry dict: one integer key followed by two string keys.
  Leaf values[3] = {11, 12, 13};
  auto dict = ContainerHandle<Leaf>(Kind::Dict, 3);

  dict[0] = &values[0];
  dict[1] = &values[1];
  dict[2] = &values[2];

  dict.key(0) = Key(0);
  dict.key(1) = Key("key_1");
  dict.key(2) = Key("key_2");

  // Flattening must return one pointer per leaf, in insertion order.
  auto flattened = flatten(dict);
  const auto& leaf_ptrs = flattened.first;
  ASSERT_EQ(leaf_ptrs.size(), 3);
  for (size_t idx = 0; idx < 3; ++idx) {
    ASSERT_EQ(*leaf_ptrs[idx], values[idx]);
  }
}

TEST(pytree, FlattenNestedDict) {
  // Outer dict: two leaves plus a nested dict holding the remaining three.
  Leaf values[5] = {11, 12, 13, 14, 15};
  auto outer = ContainerHandle<Leaf>(Kind::Dict, 3);
  auto inner = ContainerHandle<Leaf>(Kind::Dict, 3);

  // Populate the nested dict with leaves 2..4.
  inner[0] = &values[2];
  inner[1] = &values[3];
  inner[2] = &values[4];
  inner.key(0) = Key("c2_key_0");
  inner.key(1) = Key("c2_key_1");
  inner.key(2) = Key("c2_key_2");

  // Populate the outer dict; the nested dict is moved into its last slot,
  // so `inner` must not be used after this point.
  outer[0] = &values[0];
  outer[1] = &values[1];
  outer[2] = std::move(inner);
  outer.key(0) = 0;
  outer.key(1) = Key("key_1");
  outer.key(2) = Key("key_2");

  // Flattening walks into the nested dict, so all five leaves come back in
  // the same order as `values`.
  auto flattened = flatten(outer);
  const auto& leaf_ptrs = flattened.first;
  ASSERT_EQ(leaf_ptrs.size(), 5);
  for (size_t idx = 0; idx < 5; ++idx) {
    ASSERT_EQ(*leaf_ptrs[idx], values[idx]);
  }
}