xref: /aosp_15_r20/external/pytorch/mypy_plugins/sympy_mypy_plugin.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from mypy.plugin import Plugin
2from mypy.plugins.common import add_attribute_to_class
3from mypy.types import NoneType, UnionType
4
5
class SympyPlugin(Plugin):
    """Mypy plugin that attaches a hook to ``sympy.core.basic.Basic``."""

    def get_base_class_hook(self, fullname: str):
        """Return the base-class hook registered for ``fullname``.

        ``add_assumptions`` is returned when ``fullname`` is exactly
        ``"sympy.core.basic.Basic"``; any other name yields ``None``.
        """
        # Guard clause: only the sympy ``Basic`` class is of interest.
        if fullname != "sympy.core.basic.Basic":
            return None
        return add_assumptions
11
12
def add_assumptions(ctx) -> None:
    """Add one ``is_<name>`` attribute to ``ctx.cls`` per sympy assumption.

    Each attribute is registered through ``add_attribute_to_class`` with a
    type that is the union of ``builtins.bool`` and ``None``.
    """
    # Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
    # (do not import sympy to speedup mypy plugin load time)
    assumption_names = (
        "hermitian",
        "prime",
        "noninteger",
        "negative",
        "antihermitian",
        "infinite",
        "finite",
        "irrational",
        "extended_positive",
        "nonpositive",
        "odd",
        "algebraic",
        "integer",
        "rational",
        "extended_real",
        "nonnegative",
        "transcendental",
        "extended_nonzero",
        "extended_negative",
        "composite",
        "complex",
        "imaginary",
        "nonzero",
        "zero",
        "even",
        "positive",
        "polar",
        "extended_nonpositive",
        "extended_nonnegative",
        "real",
        "commutative",
    )
    for name in assumption_names:
        # A fresh type object is built for every attribute rather than
        # sharing a single instance between them.
        bool_type = ctx.api.named_type("builtins.bool")
        bool_or_none = UnionType([bool_type, NoneType()])
        add_attribute_to_class(ctx.api, ctx.cls, f"is_{name}", bool_or_none)
56
57
def plugin(version: str):
    """Return the plugin class defined by this module.

    ``version`` is accepted but not inspected; ``SympyPlugin`` is returned
    for every value.
    """
    plugin_class = SympyPlugin
    return plugin_class
60