# Names exported by ``from <this module> import *``: the two factory coroutines.
__all__ = ('create_subprocess_exec', 'create_subprocess_shell')
2
3import subprocess
4
5from . import events
6from . import protocols
7from . import streams
8from . import tasks
9from .log import logger
10
11
# Special values for the stdin/stdout/stderr arguments, taken verbatim from the
# standard subprocess module so callers of this module need not import it.
PIPE, STDOUT, DEVNULL = subprocess.PIPE, subprocess.STDOUT, subprocess.DEVNULL
15
16
class SubprocessStreamProtocol(streams.FlowControlMixin,
                               protocols.SubprocessProtocol):
    """Like StreamReaderProtocol, but for a subprocess.

    The child's stdout (fd 1) and stderr (fd 2) pipes are exposed as
    ``StreamReader`` objects and its stdin (fd 0) pipe as a ``StreamWriter``.
    The subprocess transport is closed only once the process has exited
    *and* every output pipe has reported connection loss.
    """

    def __init__(self, limit, loop):
        super().__init__(loop=loop)
        # Buffer limit handed to each StreamReader created for an output pipe.
        self._limit = limit
        # Stream endpoints; filled in by connection_made() for every pipe
        # the transport actually provides, otherwise left as None.
        self.stdin = self.stdout = self.stderr = None
        self._transport = None
        self._process_exited = False
        # Output fds (1 and/or 2) whose pipes have not been lost yet.
        self._pipe_fds = []
        # Completed when the stdin pipe is lost; returned to the stdin
        # StreamWriter through _get_close_waiter().
        self._stdin_closed = self._loop.create_future()

    def __repr__(self):
        parts = [self.__class__.__name__]
        for label in ('stdin', 'stdout', 'stderr'):
            stream = getattr(self, label)
            if stream is not None:
                parts.append(f'{label}={stream!r}')
        return '<{}>'.format(' '.join(parts))

    def connection_made(self, transport):
        self._transport = transport

        # Output pipes first: stdout (fd 1), then stderr (fd 2).
        for fd, attr in ((1, 'stdout'), (2, 'stderr')):
            pipe_transport = transport.get_pipe_transport(fd)
            if pipe_transport is None:
                continue
            reader = streams.StreamReader(limit=self._limit,
                                          loop=self._loop)
            setattr(self, attr, reader)
            reader.set_transport(pipe_transport)
            self._pipe_fds.append(fd)

        # Input pipe last: stdin (fd 0) gets a writer with no paired reader.
        stdin_transport = transport.get_pipe_transport(0)
        if stdin_transport is not None:
            self.stdin = streams.StreamWriter(stdin_transport,
                                              protocol=self,
                                              reader=None,
                                              loop=self._loop)

    def pipe_data_received(self, fd, data):
        target = {1: self.stdout, 2: self.stderr}.get(fd)
        if target is not None:
            target.feed_data(data)

    def pipe_connection_lost(self, fd, exc):
        if fd == 0:
            writer = self.stdin
            if writer is not None:
                writer.close()
            self.connection_lost(exc)
            if exc is not None:
                self._stdin_closed.set_exception(exc)
                # wait_closed() is optional for callers, so an exception
                # nobody awaits must not be logged as "never retrieved".
                self._stdin_closed._log_traceback = False
            else:
                self._stdin_closed.set_result(None)
            return

        target = {1: self.stdout, 2: self.stderr}.get(fd)
        if target is not None:
            if exc is not None:
                target.set_exception(exc)
            else:
                target.feed_eof()

        if fd in self._pipe_fds:
            self._pipe_fds.remove(fd)
        self._maybe_close_transport()

    def process_exited(self):
        self._process_exited = True
        self._maybe_close_transport()

    def _maybe_close_transport(self):
        # Close only when both conditions hold: the child has exited and no
        # output pipe is still open.
        if self._process_exited and not self._pipe_fds:
            self._transport.close()
            self._transport = None

    def _get_close_waiter(self, stream):
        # Only the stdin writer has a close waiter; any other stream gets None.
        return self._stdin_closed if stream is self.stdin else None
116
117
class Process:
    """Handle for a child process returned by create_subprocess_exec() and
    create_subprocess_shell().

    ``stdin``, ``stdout`` and ``stderr`` are the stream objects built by the
    protocol for each piped standard stream, or ``None`` for a stream that was
    not piped.  Signalling and waiting are delegated to the subprocess
    transport.
    """

    def __init__(self, transport, protocol, loop):
        self._transport = transport
        self._protocol = protocol
        self._loop = loop
        self.stdin = protocol.stdin
        self.stdout = protocol.stdout
        self.stderr = protocol.stderr
        self.pid = transport.get_pid()

    def __repr__(self):
        return f'<{self.__class__.__name__} {self.pid}>'

    @property
    def returncode(self):
        """Return code as reported by the transport's get_returncode()."""
        return self._transport.get_returncode()

    async def wait(self):
        """Wait until the process exit and return the process return code."""
        return await self._transport._wait()

    def send_signal(self, signal):
        """Send *signal* to the child process through the transport."""
        self._transport.send_signal(signal)

    def terminate(self):
        """Ask the transport to terminate the child process."""
        self._transport.terminate()

    def kill(self):
        """Ask the transport to kill the child process."""
        self._transport.kill()

    async def _feed_stdin(self, input):
        """Write *input* to stdin (skipped when None), drain, then close stdin.

        stdin is closed even when there is nothing to write, so that a child
        reading its standard input until EOF is not left blocked forever.
        BrokenPipeError and ConnectionResetError are ignored, as
        communicate() documents.
        """
        debug = self._loop.get_debug()
        try:
            if input is not None:
                self.stdin.write(input)
                if debug:
                    logger.debug(
                        '%r communicate: feed stdin (%s bytes)',
                        self, len(input))
            await self.stdin.drain()
        except (BrokenPipeError, ConnectionResetError) as exc:
            # communicate() ignores BrokenPipeError and ConnectionResetError.
            # Both write() and drain() are inside the try because either may
            # surface these errors.
            if debug:
                logger.debug('%r communicate: stdin got %r', self, exc)

        if debug:
            logger.debug('%r communicate: close stdin', self)
        self.stdin.close()

    async def _noop(self):
        """Placeholder coroutine for a stream that is not piped."""
        return None

    async def _read_stream(self, fd):
        """Read the stdout (fd 1) or stderr (fd 2) stream to the end, then
        close that pipe's transport and return the data read."""
        transport = self._transport.get_pipe_transport(fd)
        if fd == 2:
            stream = self.stderr
        else:
            assert fd == 1
            stream = self.stdout
        if self._loop.get_debug():
            name = 'stdout' if fd == 1 else 'stderr'
            logger.debug('%r communicate: read %s', self, name)
        output = await stream.read()
        if self._loop.get_debug():
            name = 'stdout' if fd == 1 else 'stderr'
            logger.debug('%r communicate: close %s', self, name)
        transport.close()
        return output

    async def communicate(self, input=None):
        """Interact with the child process.

        If stdin is piped, *input* (when not None) is written to it and stdin
        is then closed. This happens even when *input* is None, so the child
        sees end-of-input. *input* is ignored when stdin is not piped.
        stdout and stderr, when piped, are read concurrently until EOF. After
        that the method waits for the process to exit and returns
        ``(stdout_data, stderr_data)``. An element is ``None`` when the
        corresponding stream was not piped.
        """
        if self.stdin is not None:
            stdin = self._feed_stdin(input)
        else:
            stdin = self._noop()
        if self.stdout is not None:
            stdout = self._read_stream(1)
        else:
            stdout = self._noop()
        if self.stderr is not None:
            stderr = self._read_stream(2)
        else:
            stderr = self._noop()
        stdin, stdout, stderr = await tasks.gather(stdin, stdout, stderr)
        await self.wait()
        return (stdout, stderr)
201
202
async def create_subprocess_shell(cmd, stdin=None, stdout=None, stderr=None,
                                  limit=streams._DEFAULT_LIMIT, **kwds):
    """Start *cmd* through the running loop's ``subprocess_shell()``.

    *stdin*, *stdout*, *stderr* and any extra keyword arguments are forwarded
    unchanged. *limit* is passed on as the ``limit`` of the StreamReader
    objects created for piped output. Returns a :class:`Process`.
    """
    running_loop = events.get_running_loop()

    def make_protocol():
        return SubprocessStreamProtocol(limit=limit, loop=running_loop)

    transport, protocol = await running_loop.subprocess_shell(
        make_protocol, cmd,
        stdin=stdin, stdout=stdout, stderr=stderr, **kwds)
    return Process(transport, protocol, running_loop)
213
214
async def create_subprocess_exec(program, *args, stdin=None, stdout=None,
                                 stderr=None, limit=streams._DEFAULT_LIMIT,
                                 **kwds):
    """Start *program* with *args* through the running loop's
    ``subprocess_exec()``.

    *stdin*, *stdout*, *stderr* and any extra keyword arguments are forwarded
    unchanged. *limit* is passed on as the ``limit`` of the StreamReader
    objects created for piped output. Returns a :class:`Process`.
    """
    running_loop = events.get_running_loop()

    def make_protocol():
        return SubprocessStreamProtocol(limit=limit, loop=running_loop)

    transport, protocol = await running_loop.subprocess_exec(
        make_protocol, program, *args,
        stdin=stdin, stdout=stdout, stderr=stderr, **kwds)
    return Process(transport, protocol, running_loop)
227