"""This script updates the file torch/masked/_docs.py that contains the generated doc-strings for various masked operations. The update should be triggered whenever a new masked operation is introduced to torch.masked package. Running the script requires that torch package is functional. """ import os def main() -> None: target = os.path.join("torch", "masked", "_docs.py") try: import torch except ImportError as msg: print(f"Failed to import torch required to build {target}: {msg}") return if os.path.isfile(target): with open(target) as _f: current_content = _f.read() else: current_content = "" _new_content = [] _new_content.append( """\ # -*- coding: utf-8 -*- # This file is generated, do not modify it! # # To update this file, run the update masked docs script as follows: # # python tools/update_masked_docs.py # # The script must be called from an environment where the development # version of torch package can be imported and is functional. # """ ) for func_name in sorted(torch.masked._ops.__all__): func = getattr(torch.masked._ops, func_name) func_doc = torch.masked._generate_docstring(func) # type: ignore[no-untyped-call, attr-defined] _new_content.append(f'{func_name}_docstring = """{func_doc}"""\n') new_content = "\n".join(_new_content) if new_content == current_content: print(f"Nothing to update in {target}") return with open(target, "w") as _f: _f.write(new_content) print(f"Successfully updated {target}") if __name__ == "__main__": main()