# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-ignore-all-errors

import functools
from dataclasses import fields
from typing import Hashable, Set


class _UnionTag(str):
    """String naming the active field of a union, bound to its union class.

    Comparing a tag with ``==`` asserts that the other operand is a string
    and that it names one of the dataclass fields of the bound class.
    """

    # The class whose dataclass field names are the valid tag values.
    _cls: Hashable

    @staticmethod
    def create(t, cls):
        """Return a tag with text ``t`` that has ``cls`` attached as ``_cls``."""
        new_tag = _UnionTag(t)
        # A freshly constructed tag must not carry a bound class yet.
        assert not hasattr(new_tag, "_cls")
        new_tag._cls = cls
        return new_tag

    def __eq__(self, cmp) -> bool:
        """Compare by text after asserting ``cmp`` is a valid tag name."""
        assert isinstance(cmp, str)
        rhs = str(cmp)
        # Both lookups stay inside the assert so nothing is evaluated when
        # assertions are disabled.
        assert rhs in _get_field_names(
            self._cls
        ), f"{rhs} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
        return rhs == str(self)

    def __hash__(self):
        """Hash exactly like the plain string with the same text."""
        return str.__hash__(self)


@functools.lru_cache(maxsize=None)
def _get_field_names(cls) -> Set[str]:
    """Return the set of dataclass field names of ``cls`` (cached per class)."""
    names = set()
    for field in fields(cls):
        names.add(field.name)
    return names


class _Union:
    """Base for tagged-union style dataclasses.

    Relies on ``dataclasses.fields`` and ``__post_init__``, so subclasses are
    expected to be dataclasses. Instances are meant to be built through
    ``create``, which records which single field is active.
    """

    # Tag naming the active field; only assigned by ``create``.
    _type: _UnionTag

    @classmethod
    def create(cls, **kwargs):
        """Instantiate ``cls`` with exactly one field given as a keyword.

        All other fields are passed as ``None``; the given keyword becomes
        the instance's tag.
        """
        assert len(kwargs) == 1
        init_args = dict.fromkeys(f.name for f in fields(cls))
        init_args.update(kwargs)
        instance = cls(**init_args)  # type: ignore[arg-type]
        instance._type = _UnionTag.create(next(iter(kwargs)), cls)
        return instance

    def __post_init__(self):
        """Assert that no field name collides with this class's own API."""
        reserved = ("type", "_type", "create", "value")
        assert all(f.name not in reserved for f in fields(self))  # type: ignore[arg-type, misc]

    @property
    def type(self) -> str:
        """Tag of the active field.

        Raises:
            RuntimeError: if the instance has no ``_type`` attribute, i.e. it
                was not built through ``create``.
        """
        try:
            tag = self._type
        except AttributeError as err:
            raise RuntimeError(
                f"Please use {type(self).__name__}.create to instantiate the union type."
            ) from err
        return tag

    @property
    def value(self):
        """Value stored in the active field."""
        active = self.type
        return getattr(self, active)

    def __getattribute__(self, name):
        """Look up ``name``, rejecting reads of fields that are not active.

        Raises:
            AttributeError: if ``name`` is a dataclass field whose value is
                ``None`` and it is not the active field.
        """
        found = super().__getattribute__(name)
        if found is not None:
            return found
        if name not in _get_field_names(type(self)):  # type: ignore[arg-type]
            return found
        # Deliberately ``!=``: unlike ``==``, it does not go through the
        # tag's validating ``__eq__``.
        if name != self.type:
            raise AttributeError(f"Field {name} is not set.")
        return found

    def __str__(self):
        """Same text as ``__repr__``."""
        return self.__repr__()

    def __repr__(self):
        """Render as ``ClassName(active_field=value)``."""
        cls_name = type(self).__name__
        tag = self.type
        return f"{cls_name}({tag}={getattr(self, tag)})"