"""Setup TensorFlow as external dependency"""

_TF_HEADER_DIR = "TF_HEADER_DIR"
_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"
_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"

def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands the template //third_party/tensorflow:<tpl>.tpl into the repo.

    Args:
      repository_ctx: the repository_ctx object.
      tpl: base name of the template; the file read is "<tpl>.tpl".
      substitutions: dict mapping placeholder strings to their replacements.
      out: output path inside the repository; defaults to `tpl`.
    """
    if not out:
        out = tpl
    repository_ctx.template(
        out,
        Label("//third_party/tensorflow:%s.tpl" % tpl),
        substitutions,
    )

def _fail(msg):
    """Output failure message when auto configuration fails."""
    red = "\033[0;31m"
    no_color = "\033[0m"
    fail("%sTensorFlow Configuration Error:%s %s\n" % (red, no_color, msg))

def _is_windows(repository_ctx):
    """Returns true if the host operating system is windows."""
    return "windows" in repository_ctx.os.name.lower()

def _get_env_var(repository_ctx, name):
    """Returns the value of the required environment variable `name`.

    Fails with a configuration error when the variable is unset or empty,
    instead of the opaque key error raised by indexing os.environ directly.
    """
    value = repository_ctx.os.environ.get(name)
    if not value:
        _fail("Environment variable %s is not set or is empty." % name)
    return value

def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Executes an arbitrary shell command.

    Fails the build if the command exits non-zero, writes anything to stderr,
    or (unless `empty_stdout_fine`) produces no stdout.

    Args:
      repository_ctx: the repository_ctx object.
      cmdline: list of strings, the command to execute.
      error_msg: string, a summary of the error if the command fails.
      error_details: string, details about the error or steps to fix it.
      empty_stdout_fine: bool, if True, an empty stdout result is fine,
        otherwise it's an error.

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    result = repository_ctx.execute(cmdline)
    if (result.return_code != 0 or result.stderr or
            not (empty_stdout_fine or result.stdout)):
        _fail("\n".join([
            error_msg.strip() if error_msg else "Repository command failed",
            result.stderr.strip(),
            error_details if error_details else "",
        ]))
    return result

def _read_dir(repository_ctx, src_dir):
    """Returns a string with all files in a directory.

    Finds all files inside a directory, traversing subfolders and following
    symlinks. The returned string contains the full path of all files
    separated by line breaks.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: directory to find files from.

    Returns:
      A string of all files inside the given dir.
    """
    if _is_windows(repository_ctx):
        src_dir = src_dir.replace("/", "\\")
        find_result = _execute(
            repository_ctx,
            ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
            empty_stdout_fine = True,
        )

        # src_files will be used in genrule.outs where the paths must
        # use forward slashes.
        return find_result.stdout.replace("\\", "/")

    find_result = _execute(
        repository_ctx,
        ["find", src_dir, "-follow", "-type", "f"],
        empty_stdout_fine = True,
    )
    return find_result.stdout

def _genrule(genrule_name, command, outs):
    """Returns a string with a genrule.

    Genrule executes the given command and produces the given outputs.

    Args:
      genrule_name: A unique name for genrule target.
      command: The command to run.
      outs: Pre-formatted text listing the files generated by this rule,
        one quoted, comma-terminated entry per line.

    Returns:
      A genrule target.
    """
    return (
        "genrule(\n" +
        ' name = "' +
        genrule_name + '",\n' +
        " outs = [\n" +
        outs +
        "\n ],\n" +
        ' cmd = """\n' +
        command +
        '\n """,\n' +
        ")\n"
    )

def _norm_path(path):
    """Returns a path with '/' and remove the trailing slash."""
    path = path.replace("\\", "/")

    # endswith() is safe on an empty string, unlike path[-1].
    if path.endswith("/"):
        path = path[:-1]
    return path

def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = [],
        tf_pip_dir_rename_pair = []):
    """Returns a genrule to symlink(or copy if on Windows) a set of files.

    If src_dir is passed, files will be read from the given directory;
    otherwise we assume files are in src_files and dest_files. The generated
    command copies each file with `cp -f`.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: source directory, or None to use src_files/dest_files.
      dest_dir: directory to create symlink in.
      genrule_name: genrule name.
      src_files: list of source files instead of src_dir.
      dest_files: list of corresponding destination files.
      tf_pip_dir_rename_pair: list of the pair of tf pip parent directory to
        replace. For example, in TF pip package, the source code is under
        "tensorflow_core", and we might want to replace it with
        "tensorflow" to match the header includes.

    Returns:
      genrule target that creates the symlinks.
    """

    # Check that tf_pip_dir_rename_pair has the right length.
    rename_len = len(tf_pip_dir_rename_pair)
    if rename_len != 0 and rename_len != 2:
        _fail("The size of argument tf_pip_dir_rename_pair should be either 0 or 2, but %d is given." % rename_len)

    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        src_files = sorted(_read_dir(repository_ctx, src_dir).splitlines())

        # Destination paths are the source paths with only the leading src_dir
        # removed (a global replace would also strip a matching substring
        # deeper in the path), plus the optional directory rename.
        dest_files = []
        for src_file in src_files:
            rel = src_file[len(src_dir):] if src_file.startswith(src_dir) else src_file
            if rename_len:
                rel = rel.replace(tf_pip_dir_rename_pair[0], tf_pip_dir_rename_pair[1])
            dest_files.append(rel)

    if len(src_files) != len(dest_files):
        _fail("src_files and dest_files must have the same length, but got %d and %d." % (len(src_files), len(dest_files)))

    # If we have only one file to link we do not want to use the dest_dir, as
    # $(@D) will include the full path to the file.
    single_file = len(dest_files) == 1
    command = []
    outs = []
    for src_file, dest_file in zip(src_files, dest_files):
        if dest_file == "":
            continue
        if single_file:
            dest = "$(@D)/" + dest_file
        else:
            dest = "$(@D)/" + dest_dir + dest_file

        # Copy the headers to create a sandboxable setup.
        command.append('cp -f "%s" "%s"' % (src_file, dest))
        outs.append(' "' + dest_dir + dest_file + '",')

    # A genrule without outputs is invalid; report the real cause instead.
    if not outs:
        _fail("No files found to link for genrule %s." % genrule_name)

    return _genrule(
        genrule_name,
        " && ".join(command),
        "\n".join(outs),
    )

def _tf_pip_impl(repository_ctx):
    """Generates the repository BUILD file exposing TF headers and library.

    Reads TF_HEADER_DIR, TF_SHARED_LIBRARY_DIR and TF_SHARED_LIBRARY_NAME from
    the environment and substitutes the resulting genrules into BUILD.tpl.
    """
    tf_header_dir = _get_env_var(repository_ctx, _TF_HEADER_DIR)
    tf_header_rule = _symlink_genrule_for_dir(
        repository_ctx,
        tf_header_dir,
        "include",
        "tf_header_include",
        tf_pip_dir_rename_pair = ["tensorflow_core", "tensorflow"],
    )

    tf_shared_library_dir = _get_env_var(repository_ctx, _TF_SHARED_LIBRARY_DIR)
    tf_shared_library_name = _get_env_var(repository_ctx, _TF_SHARED_LIBRARY_NAME)
    tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)
    tf_shared_library_rule = _symlink_genrule_for_dir(
        repository_ctx,
        None,
        "",
        "libtensorflow_framework.so",
        [tf_shared_library_path],
        ["_pywrap_tensorflow_internal.lib" if _is_windows(repository_ctx) else "libtensorflow_framework.so"],
    )

    _tpl(repository_ctx, "BUILD", {
        "%{TF_HEADER_GENRULE}": tf_header_rule,
        "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
    })

# Repository rule; listing the env vars in `environ` makes Bazel re-run the
# rule when any of them changes.
tf_configure = repository_rule(
    implementation = _tf_pip_impl,
    environ = [
        _TF_HEADER_DIR,
        _TF_SHARED_LIBRARY_DIR,
        _TF_SHARED_LIBRARY_NAME,
    ],
)