# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A simple precompiler to generate deterministic pyc files for Bazel."""

# NOTE: Imports specific to the persistent worker should only be imported
# when a persistent worker is used. Avoiding the unnecessary imports
# saves significant startup time for non-worker invocations.
import argparse
import py_compile
import sys


def _create_parser() -> "argparse.ArgumentParser":
    parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
    parser.add_argument("--invalidation_mode", default="CHECKED_HASH")
    parser.add_argument("--optimize", type=int, default=-1)
    parser.add_argument("--python_version")

    parser.add_argument("--src", action="append", dest="srcs")
    parser.add_argument("--src_name", action="append", dest="src_names")
    parser.add_argument("--pyc", action="append", dest="pycs")

    parser.add_argument("--persistent_worker", action="store_true")
    parser.add_argument("--log_level", default="ERROR")
    parser.add_argument("--worker_impl", default="async")
    return parser


def _compile(options: "argparse.Namespace") -> None:
    try:
        invalidation_mode = py_compile.PycInvalidationMode[
            options.invalidation_mode.upper()
        ]
    except KeyError as e:
        raise ValueError(
            f"Unknown PycInvalidationMode: {options.invalidation_mode}"
        ) from e

    if not (len(options.srcs) == len(options.src_names) == len(options.pycs)):
        raise AssertionError(
            "Mismatched number of --src, --src_name, and/or --pyc args"
        )

    for src, src_name, pyc in zip(options.srcs, options.src_names, options.pycs):
        py_compile.compile(
            src,
            pyc,
            doraise=True,
            dfile=src_name,
            optimize=options.optimize,
            invalidation_mode=invalidation_mode,
        )

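# Example (illustrative; file names invented for the example): a non-worker
# invocation compiles each --src into its matching --pyc using the flags
# defined in _create_parser. Because the parser sets fromfile_prefix_chars="@",
# the same flags can also be supplied via an @params file.
#
#   python precompiler.py \
#       --invalidation_mode CHECKED_HASH \
#       --src foo.py --src_name mypkg/foo.py --pyc foo.pyc
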
# A stub type alias for readability.
# See the Bazel WorkRequest object definition:
# https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/worker_protocol.proto
JsonWorkRequest = object

# A stub type alias for readability.
# See the Bazel WorkResponse object definition:
# https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/worker_protocol.proto
JsonWorkResponse = object


class _SerialPersistentWorker:
    """Simple, synchronous, serial persistent worker."""

    def __init__(self, instream: "typing.TextIO", outstream: "typing.TextIO"):
        self._instream = instream
        self._outstream = outstream
        self._parser = _create_parser()

    def run(self) -> None:
        try:
            while True:
                request = None
                try:
                    request = self._get_next_request()
                    if request is None:
                        _logger.info("Empty request: exiting")
                        break
                    response = self._process_request(request)
                    if response:  # May be None for a cancel request
                        self._send_response(response)
                except Exception:
                    _logger.exception("Unhandled error: request=%s", request)
                    output = (
                        f"Unhandled error:\nRequest: {request}\n"
                        + traceback.format_exc()
                    )
                    request_id = 0 if not request else request.get("requestId", 0)
                    self._send_response(
                        {
                            "exitCode": 3,
                            "output": output,
                            "requestId": request_id,
                        }
                    )
        finally:
            _logger.info("Worker shutting down")

    def _get_next_request(self) -> "object | None":
        line = self._instream.readline()
        if not line:
            return None
        return json.loads(line)

    def _process_request(self, request: "JsonWorkRequest") -> "JsonWorkResponse | None":
        if request.get("cancel"):
            return None
        options = self._options_from_request(request)
        _compile(options)
        response = {
            "requestId": request.get("requestId", 0),
            "exitCode": 0,
        }
        return response

    def _options_from_request(
        self, request: "JsonWorkRequest"
    ) -> "argparse.Namespace":
        options = self._parser.parse_args(request["arguments"])
        if request.get("sandboxDir"):
            prefix = request["sandboxDir"]
            options.srcs = [os.path.join(prefix, v) for v in options.srcs]
            options.pycs = [os.path.join(prefix, v) for v in options.pycs]
        return options

    def _send_response(self, response: "JsonWorkResponse") -> None:
        self._outstream.write(json.dumps(response) + "\n")
        self._outstream.flush()

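# Illustrative sketch of the one-JSON-object-per-line exchange handled by
# _SerialPersistentWorker above. Only the fields this file actually reads and
# writes are shown ("requestId", "arguments", "cancel", "sandboxDir",
# "exitCode", "output"); see the worker_protocol.proto link above for the
# authoritative schema. The argument values are invented for the example:
#
#   stdin : {"requestId": 1, "arguments": ["--src", "foo.py",
#            "--src_name", "mypkg/foo.py", "--pyc", "foo.pyc"]}
#   stdout: {"requestId": 1, "exitCode": 0}
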
class _AsyncPersistentWorker:
    """Asynchronous, concurrent, persistent worker."""

    def __init__(self, reader: "typing.TextIO", writer: "typing.TextIO"):
        self._reader = reader
        self._writer = writer
        self._parser = _create_parser()
        self._request_id_to_task = {}
        self._task_to_request_id = {}

    @classmethod
    async def main(cls, instream: "typing.TextIO", outstream: "typing.TextIO") -> None:
        reader, writer = await cls._connect_streams(instream, outstream)
        await cls(reader, writer).run()

    @classmethod
    async def _connect_streams(
        cls, instream: "typing.TextIO", outstream: "typing.TextIO"
    ) -> "tuple[asyncio.StreamReader, asyncio.StreamWriter]":
        loop = asyncio.get_event_loop()
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)
        await loop.connect_read_pipe(lambda: protocol, instream)

        w_transport, w_protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, outstream
        )
        writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
        return reader, writer

    async def run(self) -> None:
        while True:
            _logger.info("pending requests: %s", len(self._request_id_to_task))
            request = await self._get_next_request()
            request_id = request.get("requestId", 0)
            task = asyncio.create_task(
                self._process_request(request), name=f"request_{request_id}"
            )
            self._request_id_to_task[request_id] = task
            self._task_to_request_id[task] = request_id
            task.add_done_callback(self._handle_task_done)

    async def _get_next_request(self) -> "JsonWorkRequest":
        _logger.debug("awaiting line")
        line = await self._reader.readline()
        _logger.debug("recv line: %s", line)
        return json.loads(line)

    def _handle_task_done(self, task: "asyncio.Task") -> None:
        request_id = self._task_to_request_id[task]
        _logger.info("task done: %s %s", request_id, task)
        del self._task_to_request_id[task]
        del self._request_id_to_task[request_id]

    async def _process_request(self, request: "JsonWorkRequest") -> None:
        _logger.info("request %s: start: %s", request.get("requestId"), request)
        try:
            if request.get("cancel", False):
                await self._process_cancel_request(request)
            else:
                await self._process_compile_request(request)
        except asyncio.CancelledError:
            _logger.info(
                "request %s: cancel received, stopping processing",
                request.get("requestId"),
            )
            # We don't send a response because we assume the request that
            # triggered cancelling sent the response.
            raise
        except:
            _logger.exception("Unhandled error: request=%s", request)
            self._send_response(
                {
                    "exitCode": 3,
                    "output": f"Unhandled error:\nRequest: {request}\n"
                    + traceback.format_exc(),
                    "requestId": 0 if not request else request.get("requestId", 0),
                }
            )

    async def _process_cancel_request(self, request: "JsonWorkRequest") -> None:
        request_id = request.get("requestId", 0)
        task = self._request_id_to_task.get(request_id)
        if not task:
            # It must have already completed, so ignore the request, per spec.
            return

        task.cancel()
        self._send_response({"requestId": request_id, "wasCancelled": True})

    async def _process_compile_request(self, request: "JsonWorkRequest") -> None:
        options = self._options_from_request(request)
        # _compile performs a variety of blocking IO calls, so run it separately.
        await asyncio.to_thread(_compile, options)
        self._send_response(
            {
                "requestId": request.get("requestId", 0),
                "exitCode": 0,
            }
        )

    def _options_from_request(self, request: "JsonWorkRequest") -> "argparse.Namespace":
        options = self._parser.parse_args(request["arguments"])
        if request.get("sandboxDir"):
            prefix = request["sandboxDir"]
            options.srcs = [os.path.join(prefix, v) for v in options.srcs]
            options.pycs = [os.path.join(prefix, v) for v in options.pycs]
        return options

    def _send_response(self, response: "JsonWorkResponse") -> None:
        _logger.info("request %s: respond: %s", response.get("requestId"), response)
        self._writer.write(json.dumps(response).encode("utf8") + b"\n")

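# Illustrative sketch of the cancellation flow handled by _AsyncPersistentWorker
# above: a follow-up request reusing an in-flight requestId with "cancel" set
# causes _process_cancel_request to cancel the registered task and reply with
# "wasCancelled" instead of an exit code. Values are invented for the example:
#
#   stdin : {"requestId": 2, "arguments": ["--src", "bar.py", "--src_name",
#            "mypkg/bar.py", "--pyc", "bar.pyc"]}
#   stdin : {"requestId": 2, "cancel": true}
#   stdout: {"requestId": 2, "wasCancelled": true}
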
def main(args: "list[str]") -> int:
    options = _create_parser().parse_args(args)

    # Persistent workers are started with the `--persistent_worker` flag.
    # See the following docs for details on persistent workers:
    # https://bazel.build/remote/persistent
    # https://bazel.build/remote/multiplex
    # https://bazel.build/remote/creating
    if options.persistent_worker:
        global asyncio, itertools, json, logging, os, traceback, _logger
        import asyncio
        import itertools
        import json
        import logging
        import os.path
        import traceback

        _logger = logging.getLogger("precompiler")
        # Only configure logging for workers. This prevents non-worker
        # invocations from spamming stderr with logging info.
        logging.basicConfig(level=getattr(logging, options.log_level))
        _logger.info("persistent worker: impl=%s", options.worker_impl)
        if options.worker_impl == "serial":
            _SerialPersistentWorker(sys.stdin, sys.stdout).run()
        elif options.worker_impl == "async":
            asyncio.run(_AsyncPersistentWorker.main(sys.stdin, sys.stdout))
        else:
            raise ValueError(f"Unknown worker impl: {options.worker_impl}")
    else:
        _compile(options)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
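
# For local experimentation only (Bazel normally starts and manages the worker
# itself), the persistent worker can be exercised by piping one JSON request
# per line to stdin; the file name and argument values are invented:
#
#   echo '{"requestId": 1, "arguments": ["--src", "foo.py", "--src_name",
#         "mypkg/foo.py", "--pyc", "foo.pyc"]}' \
#     | python precompiler.py --persistent_worker --worker_impl serial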