xref: /aosp_15_r20/external/pytorch/tools/update_masked_docs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1"""This script updates the file torch/masked/_docs.py that contains
2the generated doc-strings for various masked operations. The update
should be triggered whenever a new masked operation is introduced to
the torch.masked package. Running the script requires that the torch
package is functional.
6"""
7
8import os
9
10
def main() -> None:
    """Regenerate ``torch/masked/_docs.py`` from the current masked operations.

    For every name listed in ``torch.masked._ops.__all__`` (in sorted order)
    a ``<name>_docstring = \"\"\"...\"\"\"`` assignment is generated, preceded
    by a fixed "this file is generated" header.  The target file is rewritten
    only when the generated text differs from what is already on disk.

    The target path is relative to the current working directory; the
    generated header documents the invocation as
    ``python tools/update_masked_docs.py``.

    If ``torch`` cannot be imported, a message is printed and the function
    returns without touching the target file.
    """
    target = os.path.join("torch", "masked", "_docs.py")

    try:
        import torch
    except ImportError as msg:
        print(f"Failed to import torch required to build {target}: {msg}")
        return

    # The generated file declares itself as UTF-8 in its header, so it is
    # read and written as UTF-8 explicitly instead of relying on the
    # locale's preferred encoding, which is not UTF-8 on every platform.
    if os.path.isfile(target):
        with open(target, encoding="utf-8") as _f:
            current_content = _f.read()
    else:
        current_content = ""

    header = """\
# -*- coding: utf-8 -*-
# This file is generated, do not modify it!
#
# To update this file, run the update masked docs script as follows:
#
#   python tools/update_masked_docs.py
#
# The script must be called from an environment where the development
# version of torch package can be imported and is functional.
#
"""
    _new_content = [header]

    # Sorted so the output is stable regardless of the order in __all__.
    for func_name in sorted(torch.masked._ops.__all__):
        func = getattr(torch.masked._ops, func_name)
        func_doc = torch.masked._generate_docstring(func)  # type: ignore[no-untyped-call, attr-defined]
        _new_content.append(f'{func_name}_docstring = """{func_doc}"""\n')

    # Each entry already ends with a newline, so joining with "\n" leaves one
    # blank line between consecutive entries.
    new_content = "\n".join(_new_content)

    # Skip the write when nothing changed so the file is not touched needlessly.
    if new_content == current_content:
        print(f"Nothing to update in {target}")
        return

    with open(target, "w", encoding="utf-8") as _f:
        _f.write(new_content)

    print(f"Successfully updated {target}")
57
58
# Regenerate the docs only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
61