# Source: /aosp_15_r20/external/pytorch/torch/masked/__init__.py
# (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""Package entry point for ``torch.masked``.

Re-exports the masked operations implemented in ``torch.masked._ops`` together
with the ``MaskedTensor`` class and its helper/constructor functions from
``torch.masked.maskedtensor``, so callers can use them directly from
``torch.masked``.
"""

# The underscore-prefixed helpers are imported here but intentionally left out
# of ``__all__``: they stay reachable as attributes of this package without
# being part of the star-import surface.
# NOTE(review): presumably kept for internal callers/tests -- confirm before
# removing any of them.
from torch.masked._ops import (
    _canonical_dim,
    _combine_input_and_mask,
    _generate_docstring,
    _input_mask,
    _output_mask,
    _reduction_identity,
    _where,
    amax,
    amin,
    argmax,
    argmin,
    cumprod,
    cumsum,
    log_softmax,
    logaddexp,
    logsumexp,
    mean,
    median,
    norm,
    normalize,
    prod,
    softmax,
    softmin,
    std,
    sum,
    var,
)
from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
from torch.masked.maskedtensor.creation import as_masked_tensor, masked_tensor


# Public API of the package, kept in case-insensitive alphabetical order.
# Every non-underscore name imported above must appear here exactly once.
# Note that ``sum`` deliberately shadows the builtin inside this namespace.
__all__ = [
    "amax",
    "amin",
    "argmax",
    "argmin",
    "as_masked_tensor",
    "cumprod",
    "cumsum",
    "is_masked_tensor",
    "log_softmax",
    "logaddexp",
    "logsumexp",
    "masked_tensor",
    "MaskedTensor",
    "mean",
    "median",
    "norm",
    "normalize",
    "prod",
    "softmax",
    "softmin",
    "std",
    "sum",
    "var",
]