# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from collections import namedtuple
from typing import Any, Dict

import torch

# @manual=//executorch/extension/pytree:pybindings
from executorch.extension.pytree import (
    broadcast_to_and_flatten,
    register_custom,
    tree_flatten,
    tree_map,
    tree_unflatten,
    TreeSpec,
)


# pyre-fixme[11]: Annotation `TreeSpec` is not defined as a type.
def _spec(o: Any) -> TreeSpec:
    """Flatten ``o`` and return only the resulting TreeSpec (leaves are discarded)."""
    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
    _, spec = tree_flatten(o)
    return spec


def _spec_str(type_char: str, n: int) -> str:
    """Build the spec string for a container of ``n`` children, each a single leaf.

    ``type_char`` is the container's type letter as used in this file
    ('T' for tuple, 'L' for list, 'N' for namedtuple). Every child is one
    leaf, so each contributes ``#1`` to the header and ``$`` to the body.

    e.g. ('T', 3) -> 'T3#1#1#1($,$,$)'
    """
    header = f"{type_char}{n}" + "#1" * n
    body = ",".join(["$"] * n)
    return f"{header}({body})"


# Constructs string representation of pytree spec of Dict, keys can be str or int, every value is leaf.
# e.g.: {'a': 1, 2: 2} -> D2#1#1('a':$,2:$)
def _spec_str_dict(d: Dict[Any, Any]) -> str:
    """Build the spec string for a dict whose values are all single leaves.

    Keys appear in the dict's iteration order. A str key is wrapped in single
    quotes; any other key is rendered with ``str()``.
    """

    def _key_repr(key: Any) -> str:
        # The spec format quotes string keys but not other keys (e.g. ints).
        return f"'{key}'" if isinstance(key, str) else str(key)

    n = len(d)
    entries = ",".join(f"{_key_repr(key)}:$" for key in d)
    return f"D{n}" + "#1" * n + f"({entries})"


class TestPytree(unittest.TestCase):
    """Tests for the executorch pytree bindings: flatten/unflatten, specs, map, broadcast."""

    def test(self):
        SPEC = "D4#2#1#2#2('a':L2#1#1($,$),1:$,2:T2#1#1($,$),'str':D2#1#1('str':$,'str2':$))"
        d = {
            "a": [777, 1],
            1: 4,
            2: ("ta", 2),
            "str": {"str": 23, "str2": "47str"},
        }
        leaves, pytree = tree_flatten(d)
        self.assertEqual(leaves, [777, 1, 4, "ta", 2, 23, "47str"])
        pytree_str = pytree.to_str()
        self.assertEqual(pytree_str, SPEC)

        # Replacement leaves alternate between int (even index) and str (odd index).
        leaves_test = [
            i + 13 if i % 2 == 0 else str(i + 13) for i in range(len(leaves))
        ]

        tree_test = pytree.tree_unflatten(leaves_test)
        self.assertEqual(
            tree_test,
            {"a": [13, "14"], 1: 15, 2: ("16", 17), "str": {"str": "18", "str2": 19}},
        )

        # Parsing the spec string and serializing it again must round-trip.
        pytree_from = TreeSpec.from_str(SPEC)
        spec_str_to = pytree_from.to_str()
        self.assertEqual(SPEC, spec_str_to)

    def test_extract_nested_list(self):
        nested_struct = (1, 2, [3, 4])
        _, pytree = tree_flatten(nested_struct)
        self.assertEqual(pytree.to_str(), "T3#1#1#2($,$,L2#1#1($,$))")

    def test_extract_nested_dict(self):
        nested_struct = (1, 2, {3: 4, "str": 6})
        _, pytree = tree_flatten(nested_struct)
        self.assertEqual(pytree.to_str(), "T3#1#1#2($,$,D2#1#1(3:$,'str':$))")

    def test_extracted_scalar(self):
        struct = 4
        _, pytree = tree_flatten(struct)
        self.assertEqual(pytree.to_str(), "$")

    def test_map(self):
        struct = (1, 2, [3, 4])
        struct_map = tree_map(lambda x: 2 * x, struct)
        self.assertEqual(struct_map, (2, 4, [6, 8]))

    def test_treespec_equality(self):
        self.assertTrue(TreeSpec.from_str("$") == TreeSpec.from_str("$"))
        self.assertTrue(_spec([1]) == TreeSpec.from_str("L1#1($)"))
        # The trailing comma matters: `(1)` is just the int 1 (a leaf), while
        # `(1,)` is a one-element tuple, which is what these checks compare.
        self.assertTrue(_spec((1,)) != _spec([1]))
        self.assertTrue(_spec((1,)) == _spec((2,)))

    def test_flatten_unflatten_leaf(self):
        def run_test_with_leaf(leaf):
            values, treespec = tree_flatten(leaf)
            self.assertEqual(values, [leaf])
            self.assertEqual(treespec, TreeSpec.from_str("$"))

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, leaf)

        run_test_with_leaf(1)
        run_test_with_leaf(1.0)
        run_test_with_leaf(None)
        run_test_with_leaf(bool)

    def test_flatten_unflatten_list(self):
        def run_test(lst):
            spec = _spec_str("L", len(lst))

            expected_spec = TreeSpec.from_str(spec)
            values, treespec = tree_flatten(lst)
            self.assertIsInstance(values, list)
            self.assertEqual(values, lst)
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, lst)
            self.assertIsInstance(unflattened, list)

        run_test([])
        run_test([1.0, 2])
        run_test([torch.tensor([1.0, 2]), 2, 10, 9, 11])

    def test_flatten_unflatten_tuple(self):
        def run_test(tup):
            spec = _spec_str("T", len(tup))

            expected_spec = TreeSpec.from_str(spec)
            values, treespec = tree_flatten(tup)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)
            self.assertIsInstance(unflattened, tuple)

        run_test(())
        run_test((1.0,))
        run_test((1.0, 2))
        run_test((torch.tensor([1.0, 2]), 2, 10, 9, 11))

    def test_flatten_unflatten_namedtuple(self):
        Point = namedtuple("Point", ["x", "y"])

        def run_test(tup):
            spec = _spec_str("N", len(tup))
            expected_spec = TreeSpec.from_str(spec)

            values, treespec = tree_flatten(tup)
            self.assertIsInstance(values, list)

            self.assertEqual(values, list(tup))
            self.assertEqual(treespec, expected_spec)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, tup)

        run_test(Point(1.0, 2))
        run_test(Point(torch.tensor(1.0), 2))

    def test_flatten_unflatten_torch_namedtuple_return_type(self):
        x = torch.randn(3, 3)
        expected = torch.max(x, dim=0)

        values, spec = tree_flatten(expected)
        result = tree_unflatten(values, spec)

        self.assertEqual(type(result), type(expected))
        self.assertEqual(result, expected)

    def test_flatten_unflatten_dict(self):
        def run_test(d):
            spec = _spec_str_dict(d)

            values, treespec = tree_flatten(d)
            self.assertIsInstance(values, list)
            self.assertEqual(values, list(d.values()))
            self.assertEqual(treespec, TreeSpec.from_str(spec))

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, d)
            self.assertIsInstance(unflattened, dict)

        run_test({})
        run_test({"a": 1})
        run_test({"abcdefg": torch.randn(2, 3)})
        run_test({1: torch.randn(2, 3)})
        run_test({"a": 1, "b": 2, "c": torch.randn(2, 3)})

    def test_flatten_unflatten_nested(self):
        def run_test(pytree):
            values, treespec = tree_flatten(pytree)
            self.assertIsInstance(values, list)

            unflattened = tree_unflatten(values, treespec)
            self.assertEqual(unflattened, pytree)

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 0, "b": [{"c": 1}]},
            {"a": 0, "b": [1, {"c": 2}, torch.randn(3)], "c": (torch.randn(2, 3), 1)},
        ]
        for case in cases:
            run_test(case)

    def test_treemap(self):
        def run_test(pytree):
            def f(x):
                return x * 3

            # Applying f leaf-by-leaf to the flattened tree must give the same
            # leaves (in the same order) as flattening the tree_map'ed tree.
            expected_leaves = [f(x) for x in tree_flatten(pytree)[0]]
            mapped_leaves = tree_flatten(tree_map(f, pytree))[0]
            self.assertEqual(expected_leaves, mapped_leaves)

            def invf(x):
                return x // 3

            # invf undoes f for the int leaves used below, so mapping f and
            # then invf must reproduce the original tree.
            self.assertEqual(tree_map(invf, tree_map(f, pytree)), pytree)

        cases = [
            [()],
            ([],),
            {"a": ()},
            {"a": 1, "b": [{"c": 2}]},
            {"a": 0, "b": [2, {"c": 3}, 4], "c": (5, 6)},
        ]
        for case in cases:
            run_test(case)

    def test_treespec_repr(self):
        pytree = (0, [0, 0, 0])
        _, spec = tree_flatten(pytree)
        self.assertEqual(repr(spec), "T2#1#3($,L3#1#1#1($,$,$))")

    def test_custom_tree_node(self):
        class Point:
            def __init__(self, x, y, name):
                self.x = x
                self.y = y
                self.name = name

            def __repr__(self):
                return f"Point(x:{self.x}, y:{self.y}, name: {self.name})"

        def custom_flatten(p):
            # x and y are the children to flatten; name travels as context.
            children = [p.x, p.y]
            extra_data = p.name
            return (children, extra_data)

        def custom_unflatten(children, extra_data):
            return Point(*children, extra_data)

        register_custom(Point, custom_flatten, custom_unflatten)

        point = Point((1.0, 1.0, 1), 2.0, "point_name")
        children, spec = tree_flatten(point)
        point2 = tree_unflatten(children, spec)
        self.assertEqual(str(point), str(point2))

    def test_broadcast_to_and_flatten(self):
        cases = [
            (1, (), []),
            # Same (flat) structures
            ((1,), (0,), [1]),
            ([1], [0], [1]),
            ((1, 2, 3), (0, 0, 0), [1, 2, 3]),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0}, [1, 2]),
            # Mismatched (flat) structures
            ([1], (0,), None),
            ((1,), [0], None),
            ((1, 2, 3), (0, 0), None),
            ({"a": 1, "b": 2}, {"a": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "c": 0}, None),
            ({"a": 1, "b": 2}, {"a": 0, "b": 0, "c": 0}, None),
            # Same (nested) structures
            ((1, [2, 3]), (0, [0, 0]), [1, 2, 3]),
            ((1, [(2, 3), 4]), (0, [(0, 0), 0]), [1, 2, 3, 4]),
            # Mismatched (nested) structures
            ((1, [2, 3]), (0, (0, 0)), None),
            ((1, [2, 3]), (0, [0, 0, 0]), None),
            # Broadcasting single value
            (1, (0, 0, 0), [1, 1, 1]),
            (1, [0, 0, 0], [1, 1, 1]),
            (1, {"a": 0, "b": 0}, [1, 1]),
            (1, (0, [0, [0]], 0), [1, 1, 1, 1]),
            (1, (0, [0, [0, [], [[[0]]]]], 0), [1, 1, 1, 1, 1]),
            # Broadcast multiple things
            ((1, 2), ([0, 0, 0], [0, 0]), [1, 1, 1, 2, 2]),
            ((1, 2), ([0, [0, 0], 0], [0, 0]), [1, 1, 1, 1, 2, 2]),
            (([1, 2, 3], 4), ([0, [0, 0], 0], [0, 0]), [1, 2, 2, 3, 4, 4]),
        ]
        for pytree, to_pytree, expected in cases:
            _, to_spec = tree_flatten(to_pytree)
            result = broadcast_to_and_flatten(pytree, to_spec)
            self.assertEqual(result, expected, msg=str([pytree, to_spec, expected]))