# mypy: allow-untyped-defs
import torch
from typing import TypeVar

# Generic type variable; not used in this module's visible code, but kept
# because other modules may import it from here.
T = TypeVar('T')


def all_same_mode(modes):
    """Return True if every element of *modes* compares equal to the first.

    Args:
        modes: An indexable sequence of modes (``modes[0]`` is read).

    Returns:
        True when all elements compare equal to ``modes[0]`` via ``==``,
        and also True for an empty sequence.
    """
    # Feed the generator straight into ``all`` (no intermediate tuple) so it
    # stops at the first mismatch instead of evaluating every comparison.
    # ``modes[0]`` is only evaluated inside the generator, so an empty
    # sequence returns True without raising IndexError.
    return all(mode == modes[0] for mode in modes)


# Alias for ``torch._C._DisableTorchDispatch``; presumably a guard/context
# manager that disables ``__torch_dispatch__`` handling — confirm against
# the torch._C binding.
no_dispatch = torch._C._DisableTorchDispatch