xref: /aosp_15_r20/external/executorch/exir/_serialize/_dataclass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import enum
10import json
11from dataclasses import fields, is_dataclass
12from typing import Any, Dict, get_args, get_origin, get_type_hints, Union
13
14
class _DataclassEncoder(json.JSONEncoder):
    """JSON encoder for dataclasses and raw byte buffers.

    A dataclass instance is emitted as a dict keyed by field name. For every
    field whose annotation is written as a string that resolves to a
    ``typing.Union``, an extra ``f"{name}_type"`` entry records the class name
    of the stored value so a decoder can tell which Union member to rebuild.
    ``bytes`` and ``bytearray`` values are emitted as lists of ints.
    """

    # pyre-ignore
    def default(self, o: Any) -> Any:
        if is_dataclass(o):
            props = {}
            # Type hints are only needed for string annotations; resolve them
            # lazily and at most once per object instead of once per field.
            hints = None
            for field in fields(o):
                value = getattr(o, field.name)
                props[field.name] = value
                if not isinstance(field.type, str):
                    continue
                if hints is None:
                    hints = get_type_hints(type(o))
                if get_origin(hints[field.name]) is Union:
                    props[f"{field.name}_type"] = type(value).__name__
            return props

        if isinstance(o, (bytes, bytearray)):
            return list(o)

        # Anything else is not serializable: let the base class raise TypeError.
        return super().default(o)
31
32
# Dataclass Decoder
# pyre-ignore
def _is_optional(T: Any) -> bool:
    """Return True if ``T`` is a ``typing.Union`` whose last member is ``None``.

    This matches ``Optional[X]`` (``Union[X, None]``) and unions such as
    ``Union[X, Y, None]``.

    The trailing member is compared by identity with ``NoneType`` instead of
    via ``isinstance(None, member)``: ``isinstance`` raises ``TypeError`` when
    the member is a subscripted generic (``List[int]``) or a forward reference,
    so e.g. ``Union[int, List[int]]`` used to crash here.
    """
    args = get_args(T)
    return get_origin(T) is Union and len(args) > 0 and args[-1] is type(None)
41
42
# pyre-ignore
def _is_strict_union(T: Any, cls: Any, key: str) -> bool:
    """Tell whether field `key` of `cls` is a Union declared as a string.

    `T` is the field's raw annotation. Only string annotations qualify; for
    those, the resolved type hint of `key` on `cls` must have `Union` as its
    origin. Type hints are only resolved when `T` is a string.
    """
    if not isinstance(T, str):
        return False
    resolved_hint = get_type_hints(cls)[key]
    return get_origin(resolved_hint) is Union
46
47
# pyre-ignore
def _get_class_from_union(json_dict: Dict[str, Any], key: str, cls: Any) -> Any:
    """Search through all possible types in the Union annotated on field `key`
    of `cls` and select the type we want to unpack. (When a PyObject is
    serialized to JSON, the name of that type is stored under
    f"{field.name}_type".)

    Raises:
        KeyError: if `json_dict` has no f"{key}_type" entry.
        TypeError: if no member of the Union has the recorded type name.
    """
    type_name = json_dict[key + "_type"]
    union_type = get_type_hints(cls)[key]
    for member in get_args(union_type):
        # Some members (e.g. subscripted generics) may not expose __name__.
        if getattr(member, "__name__", None) == type_name:
            return member
    raise TypeError(
        f"Invalid Buffer. Field {key} has type name {type_name}, which is not a member of {union_type}."
    )
57
58
# pyre-ignore
def _json_to_dataclass(json_dict: Dict[str, Any], cls: Any = None) -> Any:
    """Initializes a dataclass given a dictionary loaded from a json,
    `json_dict`, and the expected class, `cls`, by iterating through the fields
    of the class and retrieving the data for each. If there is a field that is
    missing in the data, and that field is not the Optional type,
    `_json_to_dataclass` raises a TypeError.

    Field annotations are resolved with `typing.get_type_hints` before use, so
    string annotations and forward references nested inside generics (e.g.
    `Optional["Foo"]`, `List["Foo"]`) are decoded as the classes they name.
    Such names must be resolvable from the module that defines `cls`.

    Enum fields accept either a member name (str) or a member value (e.g. the
    integer that `json` emits for an `IntEnum` member).

    Args:
        `json_dict` : Dictionary formatted to represent a class, where fields are keys in the dictionary,
        and values are values with the required type (as outlined in the dataclass definition). If a field is
        specified to be another dataclass, the value will be another dictionary. See example below:

        SAMPLE JSON:
        {field1 : v1, inner_class : {field2_1: v2_1, field2_2: v2_2}}.

        `cls` : The class that we should be unpacking from the given dictionary
                (gives us an idea of what fields and values will be present in `json_dict`)

        SAMPLE CLASSES for Above JSON:
        @dataclass
        class AnotherDataClass:
            field2_1: int
            field2_2: str

        @dataclass
        class Example:
            field1 : str
            inner_class: AnotherDataClass

    Returns: An initialized PyObject of class: `cls`, given the data from `json_dict`.
        If `cls` is not a dataclass (or `json_dict` already is one), `json_dict`
        is returned unchanged.

    Raises:
        TypeError: if a required (non-Optional) field is missing from `json_dict`.
    """
    if not is_dataclass(cls) or is_dataclass(json_dict):
        return json_dict

    # Resolve every annotation once. Raw annotations may be strings or contain
    # ForwardRefs, which can be neither called nor recognized as dataclasses.
    hints = get_type_hints(cls)

    # initialize dataclass by iterating through all required fields
    data: Dict[str, Any] = {}
    for field in fields(cls):
        key = field.name
        T = hints[key]

        if _is_strict_union(field.type, cls, key):
            # If the field is a Union type written as a string, determine
            # exactly which member we are initializing (named by the
            # f"{key}_type" entry) and recursively construct that class.
            _cls = _get_class_from_union(json_dict, key, cls)
            data[key] = _json_to_dataclass(json_dict[key], _cls)
            continue

        if _is_optional(T):
            T = get_args(T)[0]
            value = json_dict.get(key, None)
        else:
            try:
                value = json_dict[key]
            except KeyError:
                raise TypeError(
                    f"Invalid Buffer. Received no value for field: {key}, but {key} : {field.type} is not an Optional type."
                ) from None

        if value is None:
            data[key] = None
            continue

        if is_dataclass(T):
            data[key] = _json_to_dataclass(value, T)
            continue

        if get_origin(T) is list:
            # A bare `List` has no element type; pass elements through as-is.
            elem_args = get_args(T)
            elem_type = elem_args[0] if elem_args else None
            data[key] = [_json_to_dataclass(e, elem_type) for e in value]
            continue

        # If T is a Union, then check which type in the Union it is and initialize.
        # eg. Double type in schema.py
        if get_origin(T) is Union:
            res = [x for x in get_args(T) if x == type(value)]
            data[key] = res[0](value)
            continue

        # If T is an enum, look the value up by member name (str) or by member
        # value (anything else); otherwise cast value to the required type.
        if isinstance(T, enum.EnumMeta):
            data[key] = T[value] if isinstance(value, str) else T(value)
        else:
            data[key] = T(value)
    return cls(**data)
146