xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unification/variable.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from contextlib import contextmanager
3from .utils import hashable
4from .dispatch import dispatch
5
# Registry of plain objects currently treated as logic variables: ``isvar``
# reads it and the ``variables`` context manager edits it in place.
# ``_glv`` is a short alias bound to the very same set object.
_global_logic_variables = _glv = set()  # type: ignore[var-annotated]
8
9
class Var:
    """Logic variable identified by a token.

    ``Var()`` takes a fresh token of the form ``"_<n>"`` from the counter
    stored in ``Var._id``; ``Var(x)`` uses ``x`` itself as the token; and
    ``Var(a, b, ...)`` uses the whole argument tuple.  Two variables compare
    equal only when they have exactly the same type and equal tokens.
    """

    # Number embedded in the next auto-generated token.
    _id = 1

    def __new__(cls, *token):
        """Create a variable from zero, one, or several positional tokens."""
        if not token:
            # Anonymous variable: mint a token and advance the counter.  The
            # counter is read and written on ``Var`` itself, not on ``cls``.
            label = f"_{Var._id}"
            Var._id += 1
        elif len(token) == 1:
            label = token[0]
        else:
            # Several arguments: the tuple as a whole is the token.
            label = token  # type: ignore[assignment]

        instance = object.__new__(cls)
        instance.token = label  # type: ignore[attr-defined]
        return instance

    def __str__(self):
        """Return the token's string form prefixed with ``~``."""
        return f"~{self.token!s}"  # type: ignore[attr-defined]

    __repr__ = __str__

    def __eq__(self, other):
        """Return whether *other* has the same exact type and an equal token."""
        if type(self) is not type(other):
            return False
        return self.token == other.token  # type: ignore[attr-defined]

    def __hash__(self):
        """Hash the ``(type, token)`` pair, matching what ``__eq__`` compares."""
        return hash((type(self), self.token))  # type: ignore[attr-defined]
35
36
def var():
    """Return a factory callable that builds ``Var`` objects.

    The returned callable forwards all of its positional arguments to
    ``Var``.  ``var`` itself accepts no arguments, so ``var('x')`` raises
    ``TypeError``; use ``var()('x')`` or ``Var('x')`` instead.
    """

    def make_var(*args):
        return Var(*args)

    return make_var
39
40
def vars():
    """Return a callable that, given ``n``, builds a list of ``n`` items.

    Each item of the list is the result of a separate ``var()`` call.
    """

    def make_many(n):
        return [var() for _ in range(n)]

    return make_many
43
44
@dispatch(Var)
def isvar(v):
    """Overload registered for ``Var`` arguments: always returns ``True``."""
    return True


# Bare reference to the name; evaluating it has no runtime effect.
# NOTE(review): presumably here so linters do not flag the redefinition
# below as an unused first definition -- confirm before removing.
isvar


@dispatch(object)  # type: ignore[no-redef]
def isvar(o):
    """Overload registered for ``object``: look *o* up in the global registry.

    Returns ``False`` straight away while the registry ``_glv`` is empty.
    Otherwise ``hashable(o)`` is evaluated before the membership test, so the
    ``in`` check only runs for objects that ``hashable`` accepts.
    """
    if not _glv:
        return False
    return hashable(o) and o in _glv
55
56
@contextmanager
def variables(*variables):
    """
    Temporarily register objects as logic variables.

    On entry every positional argument is added to the module-level
    ``_global_logic_variables`` set.  On exit, including exit through an
    exception, the set is restored in place to the contents it had on entry.

    Example:
        >>> # xdoctest: +SKIP("undefined vars")
        >>> with variables(1):
        ...     print(isvar(1))
        True
        >>> print(isvar(1))
        False
        >>> # Explicit ``Var`` objects
        >>> from unification import unify
        >>> x = Var('x')
        >>> unify(x, 1)
        {~x: 1}
        >>> # Context-manager approach
        >>> with variables('x'):
        ...     print(unify('x', 1))
        {'x': 1}
    """
    saved = set(_global_logic_variables)
    # Build the new set first so that an unhashable argument raises before
    # the registry has been modified.
    requested = set(variables)
    _global_logic_variables.update(requested)
    try:
        yield
    finally:
        # Restore in place (clear + refill) so every alias of the set object,
        # such as ``_glv``, sees the restored contents.
        _global_logic_variables.clear()
        _global_logic_variables.update(saved)
87