xref: /aosp_15_r20/external/pytorch/torch/utils/backcompat/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from torch._C import _set_backcompat_broadcast_warn
3from torch._C import _get_backcompat_broadcast_warn
4from torch._C import _set_backcompat_keepdim_warn
5from torch._C import _get_backcompat_keepdim_warn
6
7
class Warning:
    """Present a getter/setter pair of callables as one boolean-style flag.

    The two callables are stored unchanged. Reading or writing the
    ``enabled`` property, or calling ``get_enabled`` / ``set_enabled``,
    simply forwards to them; this class keeps no state of its own.

    Note: the class name shadows the builtin ``Warning`` exception inside
    this module; it is kept for backward compatibility with existing callers.

    Args:
        setter: callable invoked with the new value whenever the flag is set.
        getter: zero-argument callable whose result is returned whenever the
            flag is read.
    """

    def __init__(self, setter, getter):
        self.getter = getter
        self.setter = setter

    def set_enabled(self, value):
        """Forward ``value`` to the wrapped setter. Returns ``None``."""
        self.setter(value)

    def get_enabled(self):
        """Return whatever the wrapped getter currently reports."""
        return self.getter()

    @property
    def enabled(self):
        """Read/write view of the flag, backed by ``getter`` / ``setter``."""
        return self.getter()

    @enabled.setter
    def enabled(self, value):
        self.setter(value)
20
# Module-level flag objects wrapping the torch._C backcompat accessor pairs.
# NOTE(review): judging by the names, these presumably toggle torch's
# backward-compatibility warnings for broadcasting and for `keepdim`;
# confirm against the torch._C implementations.
broadcast_warning = Warning(
    setter=_set_backcompat_broadcast_warn,
    getter=_get_backcompat_broadcast_warn,
)
keepdim_warning = Warning(
    setter=_set_backcompat_keepdim_warn,
    getter=_get_backcompat_keepdim_warn,
)
23