# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A simple precompiler to generate deterministic pyc files for Bazel."""

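# Example one-shot invocation (file names are illustrative, not taken from a
# real build):
#
#   python precompiler.py \
#       --src foo.py --src_name mypkg/foo.py --pyc foo.pyc
#
# Because the parser sets fromfile_prefix_chars="@", the same arguments can
# also be read from a params file: python precompiler.py @args-file
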
# NOTE: Imports specific to the persistent worker should only be imported
# when a persistent worker is used. Avoiding the unnecessary imports
# saves significant startup time for non-worker invocations.
import argparse
import py_compile
import sys


def _create_parser() -> "argparse.ArgumentParser":
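    """Return the argument parser used by both one-shot and worker modes."""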
    parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
    parser.add_argument("--invalidation_mode", default="CHECKED_HASH")
    parser.add_argument("--optimize", type=int, default=-1)
    parser.add_argument("--python_version")

    parser.add_argument("--src", action="append", dest="srcs")
    parser.add_argument("--src_name", action="append", dest="src_names")
    parser.add_argument("--pyc", action="append", dest="pycs")

    parser.add_argument("--persistent_worker", action="store_true")
    parser.add_argument("--log_level", default="ERROR")
    parser.add_argument("--worker_impl", default="async")
    return parser


def _compile(options: "argparse.Namespace") -> None:
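    """Byte-compile each --src file to its corresponding --pyc output."""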
    try:
        invalidation_mode = py_compile.PycInvalidationMode[
            options.invalidation_mode.upper()
        ]
    except KeyError as e:
        raise ValueError(
            f"Unknown PycInvalidationMode: {options.invalidation_mode}"
        ) from e

    if not (len(options.srcs) == len(options.src_names) == len(options.pycs)):
        raise AssertionError(
            "Mismatched number of --src, --src_name, and/or --pyc args"
        )

    for src, src_name, pyc in zip(options.srcs, options.src_names, options.pycs):
        py_compile.compile(
            src,
            pyc,
            doraise=True,
            dfile=src_name,
            optimize=options.optimize,
            invalidation_mode=invalidation_mode,
        )


# A stub type alias for readability.
# See the Bazel WorkRequest object definition:
# https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/worker_protocol.proto
JsonWorkerRequest = object

# A stub type alias for readability.
# See the Bazel WorkResponse object definition:
# https://github.com/bazelbuild/bazel/blob/master/src/main/protobuf/worker_protocol.proto
JsonWorkerResponse = object

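# For reference, a work request as consumed below looks roughly like this
# (field values are illustrative):
#
#   {"requestId": 1, "arguments": ["--src", "foo.py", ...], "cancel": false}
#
# and a matching success response looks like: {"requestId": 1, "exitCode": 0}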

class _SerialPersistentWorker:
    """Simple, synchronous, serial persistent worker."""

    def __init__(self, instream: "typing.TextIO", outstream: "typing.TextIO"):
        self._instream = instream
        self._outstream = outstream
        self._parser = _create_parser()

    def run(self) -> None:
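        """Process requests from stdin until EOF (an empty read)."""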
        try:
            while True:
                request = None
                try:
                    request = self._get_next_request()
                    if request is None:
                        _logger.info("Empty request: exiting")
                        break
                    response = self._process_request(request)
                    if response:  # May be None for a cancel request
                        self._send_response(response)
                except Exception:
                    _logger.exception("Unhandled error: request=%s", request)
                    output = (
                        f"Unhandled error:\nRequest: {request}\n"
                        + traceback.format_exc()
                    )
                    request_id = 0 if not request else request.get("requestId", 0)
                    self._send_response(
                        {
                            "exitCode": 3,
                            "output": output,
                            "requestId": request_id,
                        }
                    )
        finally:
            _logger.info("Worker shutting down")

    def _get_next_request(self) -> "object | None":
        line = self._instream.readline()
        if not line:
            return None
        return json.loads(line)

    def _process_request(self, request: "JsonWorkRequest") -> "JsonWorkResponse | None":
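        """Handle one request; returns None for cancel requests."""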
        if request.get("cancel"):
            return None
        options = self._options_from_request(request)
        _compile(options)
        response = {
            "requestId": request.get("requestId", 0),
            "exitCode": 0,
        }
        return response

    def _options_from_request(
        self, request: "JsonWorkRequest"
    ) -> "argparse.Namespace":
        options = self._parser.parse_args(request["arguments"])
        if request.get("sandboxDir"):
            prefix = request["sandboxDir"]
            options.srcs = [os.path.join(prefix, v) for v in options.srcs]
            options.pycs = [os.path.join(prefix, v) for v in options.pycs]
        return options

    def _send_response(self, response: "JsonWorkResponse") -> None:
        self._outstream.write(json.dumps(response) + "\n")
        self._outstream.flush()


class _AsyncPersistentWorker:
    """Asynchronous, concurrent, persistent worker."""

    def __init__(self, reader: "typing.TextIO", writer: "typing.TextIO"):
        self._reader = reader
        self._writer = writer
        self._parser = _create_parser()
        # Track in-flight tasks in both directions: cancel requests look up
        # the task by request id; the done-callback cleans up by task.
        self._request_id_to_task = {}
        self._task_to_request_id = {}

    @classmethod
    async def main(cls, instream: "typing.TextIO", outstream: "typing.TextIO") -> None:
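        """Wrap stdio in asyncio streams, then run the request loop."""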
        reader, writer = await cls._connect_streams(instream, outstream)
        await cls(reader, writer).run()

    @classmethod
    async def _connect_streams(
        cls, instream: "typing.TextIO", outstream: "typing.TextIO"
    ) -> "tuple[asyncio.StreamReader, asyncio.StreamWriter]":
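        """Connect blocking stdin/stdout to an asyncio reader/writer pair."""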
        # get_event_loop() is deprecated inside coroutines; use the running loop.
        loop = asyncio.get_running_loop()
        reader = asyncio.StreamReader()
        protocol = asyncio.StreamReaderProtocol(reader)
        await loop.connect_read_pipe(lambda: protocol, instream)

        w_transport, w_protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, outstream
        )
        writer = asyncio.StreamWriter(w_transport, w_protocol, reader, loop)
        return reader, writer

    async def run(self) -> None:
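        """Read requests forever, dispatching each to its own concurrent task."""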
        while True:
            _logger.info("pending requests: %s", len(self._request_id_to_task))
            request = await self._get_next_request()
            request_id = request.get("requestId", 0)
            task = asyncio.create_task(
                self._process_request(request), name=f"request_{request_id}"
            )
            self._request_id_to_task[request_id] = task
            self._task_to_request_id[task] = request_id
            task.add_done_callback(self._handle_task_done)

    async def _get_next_request(self) -> "JsonWorkRequest":
        _logger.debug("awaiting line")
        line = await self._reader.readline()
        _logger.debug("recv line: %s", line)
        return json.loads(line)

    def _handle_task_done(self, task: "asyncio.Task") -> None:
        request_id = self._task_to_request_id[task]
        _logger.info("task done: %s %s", request_id, task)
        del self._task_to_request_id[task]
        del self._request_id_to_task[request_id]

    async def _process_request(self, request: "JsonWorkRequest") -> None:
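        """Route a request to the cancel or compile handler, reporting errors."""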
        _logger.info("request %s: start: %s", request.get("requestId"), request)
        try:
            if request.get("cancel", False):
                await self._process_cancel_request(request)
            else:
                await self._process_compile_request(request)
        except asyncio.CancelledError:
            _logger.info(
                "request %s: cancel received, stopping processing",
                request.get("requestId"),
            )
            # No response is sent here; the handler of the cancel request is
            # responsible for sending the cancellation response.
            raise
        except Exception:
            _logger.exception("Unhandled error: request=%s", request)
            self._send_response(
                {
                    "exitCode": 3,
                    "output": f"Unhandled error:\nRequest: {request}\n"
                    + traceback.format_exc(),
                    "requestId": 0 if not request else request.get("requestId", 0),
                }
            )

    async def _process_cancel_request(self, request: "JsonWorkRequest") -> None:
        request_id = request.get("requestId", 0)
        task = self._request_id_to_task.get(request_id)
        if not task:
            # The task must have already completed, so per the worker protocol
            # spec, ignore the cancel request.
            return

        task.cancel()
        self._send_response({"requestId": request_id, "wasCancelled": True})

    async def _process_compile_request(self, request: "JsonWorkRequest") -> None:
        options = self._options_from_request(request)
        # _compile performs a variety of blocking IO calls, so run it in a
        # separate thread to keep the event loop responsive.
        await asyncio.to_thread(_compile, options)
        self._send_response(
            {
                "requestId": request.get("requestId", 0),
                "exitCode": 0,
            }
        )

    def _options_from_request(self, request: "JsonWorkRequest") -> "argparse.Namespace":
        options = self._parser.parse_args(request["arguments"])
        if request.get("sandboxDir"):
            prefix = request["sandboxDir"]
            options.srcs = [os.path.join(prefix, v) for v in options.srcs]
            options.pycs = [os.path.join(prefix, v) for v in options.pycs]
        return options

    def _send_response(self, response: "JsonWorkResponse") -> None:
        _logger.info("request %s: respond: %s", response.get("requestId"), response)
        self._writer.write(json.dumps(response).encode("utf8") + b"\n")


def main(args: "list[str]") -> int:
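    """CLI entry point: compile once, or run as a persistent worker."""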
    options = _create_parser().parse_args(args)

    # Persistent workers are started with the `--persistent_worker` flag.
    # See the following docs for details on persistent workers:
    # https://bazel.build/remote/persistent
    # https://bazel.build/remote/multiplex
    # https://bazel.build/remote/creating
    if options.persistent_worker:
        global asyncio, json, logging, os, traceback, _logger
        import asyncio
        import json
        import logging
        import os.path
        import traceback

        _logger = logging.getLogger("precompiler")
        # Only configure logging for workers. This prevents non-worker
        # invocations from spamming stderr with logging info.
        logging.basicConfig(level=getattr(logging, options.log_level))
        _logger.info("persistent worker: impl=%s", options.worker_impl)
        if options.worker_impl == "serial":
            _SerialPersistentWorker(sys.stdin, sys.stdout).run()
        elif options.worker_impl == "async":
            asyncio.run(_AsyncPersistentWorker.main(sys.stdin, sys.stdout))
        else:
            raise ValueError(f"Unknown worker impl: {options.worker_impl}")
    else:
        _compile(options)
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))