# xref: /aosp_15_r20/external/pytorch/torch/utils/_stats.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
# NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE.
# IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils
# AND SCRUB AWAY TORCH NOTIONS THERE.
5import collections
6import functools
7from typing import OrderedDict
8
# Process-wide tally: label -> number of times it has been counted.
# Being an OrderedDict, labels stay in the order they were first counted.
simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()

def count_label(label):
    """Bump the tally for ``label`` in ``simple_call_counter`` by one.

    A label not yet present is treated as having a count of 0, so its first
    call appends it to the end of the ordered counter with a value of 1.
    """
    simple_call_counter[label] = simple_call_counter.get(label, 0) + 1
14
def count(fn):
    """Decorate ``fn`` so that every call bumps its tally in ``simple_call_counter``.

    The tally is keyed by ``fn.__qualname__``. It is updated through
    ``count_label``, so decorated functions and explicitly counted labels
    share a single counting code path. The counter is incremented before
    ``fn`` runs, so calls that raise are still counted.

    Returns a wrapper (metadata copied from ``fn`` via ``functools.wraps``)
    that forwards all positional and keyword arguments and returns ``fn``'s
    result unchanged.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # __qualname__ is read at call time, as before, rather than captured
        # once at decoration time.
        count_label(fn.__qualname__)
        return fn(*args, **kwargs)

    return wrapper
23