xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unification/more.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from .core import unify, reify  # type: ignore[attr-defined]
3from .dispatch import dispatch
4
5
def unifiable(cls):
    """Register the generic ``unify``/``reify`` implementations for ``cls``.

    After registration, two instances of ``cls`` are unified by
    ``unify_object`` and an instance is reified by ``reify_object``.
    Both work on the instance's ``__slots__`` values when the object has
    ``__slots__``, and on its ``__dict__`` otherwise.

    ``cls`` is returned unchanged, so this may also be used as a class
    decorator.

    Examples:
    >>> # xdoctest: +SKIP
    >>> class A(object):
    ...     def __init__(self, a, b):
    ...         self.a = a
    ...         self.b = b
    >>> unifiable(A)
    <class 'unification.more.A'>
    >>> x = var('x')
    >>> a = A(1, 2)
    >>> b = A(1, x)
    >>> unify(a, b, {})
    {~x: 2}
    """
    # Both dispatch signatures end with the substitution dict ``s``.
    unify_signature = (cls, cls, dict)
    reify_signature = (cls, dict)
    _unify.add(unify_signature, unify_object)
    _reify.add(reify_signature, reify_object)
    return cls
27    return cls
28
29
30#########
31# Reify #
32#########
33
34
def reify_object(o, s):
    """Reify a Python object with a substitution.

    Substitutes ``s`` into the attributes of ``o``. Objects that expose
    ``__slots__`` are handled slot by slot; all others through their
    instance ``__dict__``. The helpers return ``o`` itself when nothing
    changes and a freshly allocated instance of ``type(o)`` otherwise.

    >>> # xdoctest: +SKIP
    >>> class Foo(object):
    ...     def __init__(self, a, b):
    ...         self.a = a
    ...         self.b = b
    ...     def __str__(self):
    ...         return "Foo(%s, %s)"%(str(self.a), str(self.b))
    >>> x = var('x')
    >>> f = Foo(1, x)
    >>> print(f)
    Foo(1, ~x)
    >>> print(reify_object(f, {x: 2}))
    Foo(1, 2)
    """
    reifier = _reify_object_slots if hasattr(o, '__slots__') else _reify_object_dict
    return reifier(o, s)
55
56
def _reify_object_dict(o, s):
    """Reify an object whose state lives in its instance ``__dict__``.

    Returns ``o`` itself when applying ``s`` leaves its attributes
    unchanged. Otherwise returns a new instance of ``type(o)``, created
    without calling ``__init__``, whose ``__dict__`` holds the reified
    attributes.
    """
    d = reify(o.__dict__, s)
    if d == o.__dict__:
        return o
    # Allocate the replacement only once we know it is needed. __init__ is
    # bypassed because the reified attributes are copied in directly.
    obj = object.__new__(type(o))
    obj.__dict__.update(d)
    return obj
64
65
def _reify_object_slots(o, s):
    """Reify an object whose state is stored in ``__slots__``.

    Reads every slot listed in ``o.__slots__`` and applies ``s`` to the
    values. Returns ``o`` itself when nothing changes. Otherwise returns a
    new instance of ``type(o)``, created without calling ``__init__``, with
    each slot set to its reified value.

    Note: only the names in ``o.__slots__`` are visited, i.e. the slots
    declared by ``type(o)`` itself.
    """
    slots = o.__slots__
    # ``__slots__`` may be a single attribute name instead of an iterable of
    # names. Iterating a bare string would walk its characters.
    slots = (slots,) if isinstance(slots, str) else tuple(slots)
    attrs = [getattr(o, attr) for attr in slots]
    new_attrs = reify(attrs, s)
    if attrs == new_attrs:
        return o
    newobj = object.__new__(type(o))
    for slot, attr in zip(slots, new_attrs):
        setattr(newobj, slot, attr)
    return newobj
76
77
@dispatch(slice, dict)
def _reify(o, s):
    """Reify a Python ``slice`` by substituting into its start, stop and step."""
    bounds = (o.start, o.stop, o.step)
    return slice(*reify(bounds, s))
82
83
84#########
85# Unify #
86#########
87
88
def unify_object(u, v, s):
    """Unify two Python objects.

    Objects of different exact types never unify and yield ``False``.
    Otherwise the objects' ``__slots__`` values, in slot order, are unified
    when the type has ``__slots__``. If not, their instance ``__dict__``s
    are unified. The result of ``unify`` with substitution ``s`` is
    returned.

    >>> # xdoctest: +SKIP
    >>> class Foo(object):
    ...     def __init__(self, a, b):
    ...         self.a = a
    ...         self.b = b
    ...     def __str__(self):
    ...         return "Foo(%s, %s)"%(str(self.a), str(self.b))
    >>> x = var('x')
    >>> f = Foo(1, x)
    >>> g = Foo(1, 2)
    >>> unify_object(f, g, {})
    {~x: 2}
    """
    if type(u) is not type(v):
        return False
    if hasattr(u, '__slots__'):
        slots = u.__slots__
        # ``__slots__`` may be a single attribute name instead of an
        # iterable of names. Iterating a bare string would walk its characters.
        slots = (slots,) if isinstance(slots, str) else tuple(slots)
        # u and v share one exact type, so they share the same slot names.
        return unify([getattr(u, slot) for slot in slots],
                     [getattr(v, slot) for slot in slots],
                     s)
    return unify(u.__dict__, v.__dict__, s)
113
114
@dispatch(slice, slice, dict)
def _unify(u, v, s):
    """Unify two Python ``slice`` objects component-wise (start, stop, step)."""
    lhs = (u.start, u.stop, u.step)
    rhs = (v.start, v.stop, v.step)
    return unify(lhs, rhs, s)
119