# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import enum
import json
from dataclasses import fields, is_dataclass
from typing import Any, Dict, get_args, get_origin, get_type_hints, Union


class _DataclassEncoder(json.JSONEncoder):
    """JSON encoder that serializes dataclass instances and ``bytes``.

    A dataclass instance becomes a dict mapping each field name to its value.
    Some fields have an annotation that is written as a string and resolves to
    a ``Union``. For each of these, an extra ``f"{field.name}_type"`` entry
    records the runtime class name of the value, so that
    ``_json_to_dataclass`` can choose the right Union member when decoding.
    ``bytes`` values become a list of ints.
    """

    # pyre-ignore
    def default(self, o: Any) -> Any:
        # `is_dataclass` is also true for dataclass *types*. Only instances
        # carry field values, so let the base class reject the type itself.
        if is_dataclass(o) and not isinstance(o, type):
            # Resolve the annotations once per object rather than once per
            # field: `get_type_hints` re-evaluates every string annotation.
            hints = get_type_hints(type(o))
            props = {}
            for field in fields(o):
                value = getattr(o, field.name)
                props[field.name] = value
                if (
                    isinstance(field.type, str)
                    and get_origin(hints[field.name]) is Union
                ):
                    props[f"{field.name}_type"] = type(value).__name__
            return props

        if isinstance(o, bytes):
            return list(o)

        return super().default(o)


# Dataclass Decoder
# pyre-ignore
def _is_optional(T: Any) -> bool:
    """Return True if ``T`` is a ``Union`` whose last member is ``NoneType``.

    This matches ``Optional[X]`` (i.e. ``Union[X, None]``). The last member is
    compared by identity instead of with ``isinstance(None, member)``. The
    isinstance form raises ``TypeError`` when the member is a subscripted
    generic such as ``List[int]``.
    """
    args = get_args(T)
    return get_origin(T) is Union and len(args) > 0 and args[-1] is type(None)


# pyre-ignore
def _is_strict_union(T: Any, cls: Any, key: str) -> bool:
    """Return True if field ``key`` of ``cls`` is annotated with a *string*
    that resolves to a ``Union``.

    Only such fields carry the ``f"{key}_type"`` tag that
    ``_DataclassEncoder`` writes.
    """
    return isinstance(T, str) and get_origin(get_type_hints(cls)[key]) is Union


# pyre-ignore
def _get_class_from_union(json_dict: Dict[str, Any], key: str, cls: Any) -> Any:
    """Search through all possible types in the Union and select the type we
    want to unpack.

    When a PyObject is serialized to JSON, the name of the type to unpack is
    stored under the key f"{field.name}_type".

    Raises:
        KeyError: if `json_dict` has no f"{key}_type" entry.
        TypeError: if the recorded type name matches no member of the Union.
    """
    _type = json_dict[key + "_type"]
    candidates = get_args(get_type_hints(cls)[key])
    for candidate in candidates:
        # Subscripted generics (e.g. List[int]) may lack `__name__` on older
        # Python versions, so read it defensively.
        if getattr(candidate, "__name__", None) == _type:
            return candidate
    raise TypeError(
        f"Invalid Buffer. {key}_type is {_type!r}, which is not a member of "
        f"the Union for field {key}: {candidates}."
    )


# pyre-ignore
def _json_to_dataclass(json_dict: Dict[str, Any], cls: Any = None) -> Any:
    """Initialize a dataclass from a dictionary loaded from JSON.

    The function iterates over the fields of `cls` and reads the data for
    each one from `json_dict`. It recurses into nested dataclasses and lists.
    If a field is missing from the data and that field is not an Optional
    type, `_json_to_dataclass` raises a TypeError.

    Args:
        `json_dict` : Dictionary formatted to represent a class. Fields are
            keys in the dictionary, and values have the required type (as
            outlined in the dataclass definition). If a field is specified to
            be another dataclass, the value will be another dictionary. See
            example below:

            SAMPLE JSON:
            {field1 : v1, inner_class : {field2_1: v2_1, field2_2: v2_2}}.

        `cls` : The class that we should be unpacking from the given
            dictionary (gives us an idea of what fields and values will be
            present in `json_dict`). If `cls` is not a dataclass, or
            `json_dict` already is one, `json_dict` is returned unchanged.

            SAMPLE CLASSES for Above JSON:
            @dataclass
            class AnotherDataClass:
                field2_1: int
                field2_2: str

            @dataclass
            class Example:
                field1: str
                inner_class: AnotherDataClass

    Returns: An initialized PyObject of class: `cls`, given the data from
        `json_dict`.

    Raises:
        TypeError: if a non-Optional field is missing. Also raised if a
            Union-typed value matches none of the Union's members, or if a
            f"{field}_type" tag names a class outside the field's Union.
    """
    if not is_dataclass(cls) or is_dataclass(json_dict):
        return json_dict

    # Resolved annotations of `cls`. They are evaluated lazily, only when a
    # non-Union string annotation has to be turned into a real type.
    hints = None

    # initialize dataclass by iterating through all required fields
    data = {}
    for field in fields(cls):
        key = field.name
        T = field.type

        if _is_strict_union(T, cls, key):
            # The field is a string-annotated Union. The encoder recorded the
            # concrete member class, so look it up and recurse into it.
            _cls = _get_class_from_union(json_dict, key, cls)
            data[key] = _json_to_dataclass(json_dict[key], _cls)
            continue

        if isinstance(T, str):
            # Forward-reference annotation that is not a Union: evaluate it so
            # the checks below see a real type instead of a string.
            if hints is None:
                hints = get_type_hints(cls)
            T = hints[key]

        if _is_optional(T):
            T = get_args(T)[0]
            value = json_dict.get(key, None)
        else:
            try:
                value = json_dict[key]
            except KeyError:
                raise TypeError(
                    f"Invalid Buffer. Received no value for field: {key}, but {key} : {field.type} is not an Optional type."
                )

        if value is None:
            data[key] = None
            continue

        if is_dataclass(T):
            data[key] = _json_to_dataclass(value, T)
            continue

        if get_origin(T) is list:
            # A bare `List` has no element type; elements then pass through.
            elem_args = get_args(T)
            elem_type = elem_args[0] if elem_args else None
            data[key] = [_json_to_dataclass(e, elem_type) for e in value]
            continue

        # If T is a Union, find the member whose class is exactly the type of
        # the JSON value and construct it (e.g. the Double type in schema.py).
        # For generic members such as List[int], compare the origin (list).
        if get_origin(T) is Union:
            member = next(
                (
                    base
                    for base in (get_origin(m) or m for m in get_args(T))
                    if base is type(value)
                ),
                None,
            )
            if member is None:
                raise TypeError(
                    f"Invalid Buffer. Value {value!r} for field: {key} does not match any type in {T}."
                )
            data[key] = member(value)
            continue

        if isinstance(T, enum.EnumMeta):
            # Strings are looked up by member name. Anything else (e.g. the
            # int that json emits for an IntEnum member) is a member value.
            data[key] = T[value] if isinstance(value, str) else T(value)
        else:
            # Otherwise cast the value to whatever type is required.
            data[key] = T(value)
    return cls(**data)