xref: /aosp_15_r20/external/pytorch/tools/update_masked_docs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
"""This script updates the file torch/masked/_docs.py, which contains
the generated docstrings for the various masked operations. The update
should be triggered whenever a new masked operation is introduced to
the torch.masked package. Running the script requires that the torch
package is importable and functional.
"""
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerimport os
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
def main() -> None:
    """Regenerate ``torch/masked/_docs.py`` from the masked operations in torch.

    For every name listed in ``torch.masked._ops.__all__`` (in sorted order),
    a ``<name>_docstring`` assignment is produced from
    ``torch.masked._generate_docstring``. The result is written to the target
    file only when it differs from what is already on disk.

    The target path is relative to the current working directory. If torch
    cannot be imported, a message is printed and nothing is written.
    """
    target = os.path.join("torch", "masked", "_docs.py")

    try:
        import torch
    except ImportError as msg:
        print(f"Failed to import torch required to build {target}: {msg}")
        return

    # The generated file declares itself as UTF-8 in its first line, so it is
    # read and written as UTF-8 explicitly instead of relying on the
    # platform's locale-dependent default encoding.
    if os.path.isfile(target):
        with open(target, encoding="utf-8") as _f:
            current_content = _f.read()
    else:
        # A missing file compares unequal to any generated content below,
        # which forces it to be created.
        current_content = ""

    header = """\
# -*- coding: utf-8 -*-
# This file is generated, do not modify it!
#
# To update this file, run the update masked docs script as follows:
#
#   python tools/update_masked_docs.py
#
# The script must be called from an environment where the development
# version of torch package can be imported and is functional.
#
"""

    sections = [header]
    # Sorting keeps the generated file stable regardless of __all__ ordering.
    for func_name in sorted(torch.masked._ops.__all__):
        func = getattr(torch.masked._ops, func_name)
        func_doc = torch.masked._generate_docstring(func)  # type: ignore[no-untyped-call, attr-defined]
        sections.append(f'{func_name}_docstring = """{func_doc}"""\n')

    # Each section already ends with a newline; joining with another one
    # leaves a blank line between consecutive sections.
    new_content = "\n".join(sections)

    if new_content == current_content:
        print(f"Nothing to update in {target}")
        return

    with open(target, "w", encoding="utf-8") as _f:
        _f.write(new_content)

    print(f"Successfully updated {target}")
58*da0073e9SAndroid Build Coastguard Worker
# Run the update only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
61