1from __future__ import annotations 2 3import os 4import subprocess 5 6from ..util.setting import CompilerType, TestType, TOOLS_FOLDER 7from ..util.utils import print_error, remove_file 8 9 10def get_oss_binary_folder(test_type: TestType) -> str: 11 assert test_type in {TestType.CPP, TestType.PY} 12 # TODO: change the way we get binary file -- binary may not in build/bin ? 13 return os.path.join( 14 get_pytorch_folder(), "build/bin" if test_type == TestType.CPP else "test" 15 ) 16 17 18def get_oss_shared_library() -> list[str]: 19 lib_dir = os.path.join(get_pytorch_folder(), "build", "lib") 20 return [ 21 os.path.join(lib_dir, lib) 22 for lib in os.listdir(lib_dir) 23 if lib.endswith(".dylib") 24 ] 25 26 27def get_oss_binary_file(test_name: str, test_type: TestType) -> str: 28 assert test_type in {TestType.CPP, TestType.PY} 29 binary_folder = get_oss_binary_folder(test_type) 30 binary_file = os.path.join(binary_folder, test_name) 31 if test_type == TestType.PY: 32 # add python to the command so we can directly run the script by using binary_file variable 33 binary_file = "python " + binary_file 34 return binary_file 35 36 37def get_llvm_tool_path() -> str: 38 return os.environ.get( 39 "LLVM_TOOL_PATH", "/usr/local/opt/llvm/bin" 40 ) # set default as llvm path in dev server, on mac the default may be /usr/local/opt/llvm/bin 41 42 43def get_pytorch_folder() -> str: 44 # TOOLS_FOLDER in oss: pytorch/tools/code_coverage 45 return os.path.abspath( 46 os.environ.get( 47 "PYTORCH_FOLDER", os.path.join(TOOLS_FOLDER, os.path.pardir, os.path.pardir) 48 ) 49 ) 50 51 52def detect_compiler_type() -> CompilerType | None: 53 # check if user specifies the compiler type 54 user_specify = os.environ.get("CXX", None) 55 if user_specify: 56 if user_specify in ["clang", "clang++"]: 57 return CompilerType.CLANG 58 elif user_specify in ["gcc", "g++"]: 59 return CompilerType.GCC 60 61 raise RuntimeError(f"User specified compiler is not valid {user_specify}") 62 63 # auto detect 64 auto_detect_result = subprocess.check_output( 65 ["cc", "-v"], stderr=subprocess.STDOUT 66 ).decode("utf-8") 67 if "clang" in auto_detect_result: 68 return CompilerType.CLANG 69 elif "gcc" in auto_detect_result: 70 return CompilerType.GCC 71 raise RuntimeError(f"Auto detected compiler is not valid {auto_detect_result}") 72 73 74def clean_up_gcda() -> None: 75 gcda_files = get_gcda_files() 76 for item in gcda_files: 77 remove_file(item) 78 79 80def get_gcda_files() -> list[str]: 81 folder_has_gcda = os.path.join(get_pytorch_folder(), "build") 82 if os.path.isdir(folder_has_gcda): 83 # TODO use glob 84 # output = glob.glob(f"{folder_has_gcda}/**/*.gcda") 85 output = subprocess.check_output(["find", folder_has_gcda, "-iname", "*.gcda"]) 86 return output.decode("utf-8").split("\n") 87 else: 88 return [] 89 90 91def run_oss_python_test(binary_file: str) -> None: 92 # python test script 93 try: 94 subprocess.check_call( 95 binary_file, shell=True, cwd=get_oss_binary_folder(TestType.PY) 96 ) 97 except subprocess.CalledProcessError: 98 print_error(f"Binary failed to run: {binary_file}") 99