xref: /aosp_15_r20/external/executorch/exir/serde/union.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-ignore-all-errors
8
9import functools
10from dataclasses import fields
11from typing import Hashable, Set
12
13
14class _UnionTag(str):
15    _cls: Hashable
16
17    @staticmethod
18    def create(t, cls):
19        tag = _UnionTag(t)
20        assert not hasattr(tag, "_cls")
21        tag._cls = cls
22        return tag
23
24    def __eq__(self, cmp) -> bool:
25        assert isinstance(cmp, str)
26        other = str(cmp)
27        assert other in _get_field_names(
28            self._cls
29        ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
30        return str(self) == other
31
32    def __hash__(self):
33        return hash(str(self))
34
35
36@functools.lru_cache(maxsize=None)
37def _get_field_names(cls) -> Set[str]:
38    return {f.name for f in fields(cls)}
39
40
41class _Union:
42    _type: _UnionTag
43
44    @classmethod
45    def create(cls, **kwargs):
46        assert len(kwargs) == 1
47        obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs})  # type: ignore[arg-type]
48        obj._type = _UnionTag.create(next(iter(kwargs.keys())), cls)
49        return obj
50
51    def __post_init__(self):
52        assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self))  # type: ignore[arg-type, misc]
53
54    @property
55    def type(self) -> str:
56        try:
57            return self._type
58        except AttributeError as e:
59            raise RuntimeError(
60                f"Please use {type(self).__name__}.create to instantiate the union type."
61            ) from e
62
63    @property
64    def value(self):
65        return getattr(self, self.type)
66
67    def __getattribute__(self, name):
68        attr = super().__getattribute__(name)
69        if attr is None and name in _get_field_names(type(self)) and name != self.type:  # type: ignore[arg-type]
70            raise AttributeError(f"Field {name} is not set.")
71        return attr
72
73    def __str__(self):
74        return self.__repr__()
75
76    def __repr__(self):
77        return f"{type(self).__name__}({self.type}={getattr(self, self.type)})"
78