"""Regenerate torch/masked/_docs.py, the file holding the generated
doc-strings for the masked operations.

The update should be triggered whenever a new masked operation is
introduced to the torch.masked package. Running the script requires
that the torch package is importable and functional. The target path
is relative to the current working directory.
"""

import os


# Banner written at the top of the generated file. It is part of the
# generated output, so its text must stay stable between runs.
_GENERATED_HEADER = """\
# -*- coding: utf-8 -*-
# This file is generated, do not modify it!
#
# To update this file, run the update masked docs script as follows:
#
# python tools/update_masked_docs.py
#
# The script must be called from an environment where the development
# version of torch package can be imported and is functional.
#
"""


def main() -> None:
    """Rebuild the masked-ops docstring file if its content is out of date.

    For every name listed in ``torch.masked._ops.__all__`` (in sorted
    order) a ``<name>_docstring = \"\"\"...\"\"\"`` assignment is rendered
    from ``torch.masked._generate_docstring``. The result is compared with
    the current content of the target file, and the file is rewritten only
    when the two differ.

    If torch cannot be imported, a message is printed and the function
    returns without touching the file. The outcome (nothing to update /
    successfully updated) is reported on stdout.
    """
    target = os.path.join("torch", "masked", "_docs.py")

    try:
        import torch
    except ImportError as msg:
        print(f"Failed to import torch required to build {target}: {msg}")
        return

    if os.path.isfile(target):
        # The generated file declares itself as UTF-8, so decode it as such
        # instead of relying on the platform's default encoding.
        # errors="replace" makes a stale file in some other encoding compare
        # unequal (and get regenerated) rather than abort the script.
        with open(target, encoding="utf-8", errors="replace") as _f:
            current_content = _f.read()
    else:
        current_content = ""

    _new_content = [_GENERATED_HEADER]
    _new_content.extend(
        '{}_docstring = """{}"""\n'.format(
            func_name,
            torch.masked._generate_docstring(  # type: ignore[no-untyped-call, attr-defined]
                getattr(torch.masked._ops, func_name)
            ),
        )
        for func_name in sorted(torch.masked._ops.__all__)
    )

    # Every chunk already ends with a newline, so joining on "\n" leaves
    # one blank line between consecutive entries.
    new_content = "\n".join(_new_content)

    if new_content == current_content:
        print(f"Nothing to update in {target}")
        return

    # Write with the encoding announced in the header so that non-ASCII
    # characters in doc-strings survive on every platform.
    with open(target, "w", encoding="utf-8") as _f:
        _f.write(new_content)

    print(f"Successfully updated {target}")


if __name__ == "__main__":
    main()