xref: /aosp_15_r20/external/toolchain-utils/cros_utils/command_executer.py (revision 760c253c1ed00ce9abd48f8546f08516e57485fe)
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3# Copyright 2011 The ChromiumOS Authors
4# Use of this source code is governed by a BSD-style license that can be
5# found in the LICENSE file.
6
7"""Utilities to run commands in outside/inside chroot and on the board."""
8
9
import codecs
import getpass
import os
import re
import select
import signal
import subprocess
import sys
import tempfile
import time

from cros_utils import logger
21
22
# Read by GetCommandExecuter(): when true, a MockCommandExecuter is returned
# even if the caller did not ask for one. Changed through InitCommandExecuter().
mock_default = False

# Directory from which ChrootRunCommandGeneric() runs its temporary script via
# cros_sdk; the script itself is written to <chromeos_root>/src/scripts, so
# this is presumably the in-chroot mount point of that directory.
CHROMEOS_SCRIPTS_DIR = "/mnt/host/source/src/scripts"
# Known log_level names. In this file only "none", "quiet" and "verbose" are
# compared against explicitly; any other value (e.g. "average") is treated as
# neither quiet nor verbose.
LOG_LEVEL = ("none", "quiet", "average", "verbose")
27
28
def InitCommandExecuter(mock=False):
    """Set the process-wide default used by GetCommandExecuter().

    Args:
      mock: When true, every later GetCommandExecuter() call hands out a
        MockCommandExecuter, whatever its own mock argument says.
    """
    # pylint: disable=global-statement
    global mock_default
    mock_default = mock
34
35
def GetCommandExecuter(logger_to_set=None, mock=False, log_level="verbose"):
    """Return a command executer configured with log_level and logger_to_set.

    A MockCommandExecuter is returned when either the module-wide default
    (see InitCommandExecuter) or the mock argument asks for one; otherwise a
    real CommandExecuter is returned.
    """
    wants_mock = mock_default or mock
    executer_class = MockCommandExecuter if wants_mock else CommandExecuter
    return executer_class(log_level, logger_to_set)
42
43
class CommandExecuter(object):
    """Provides several methods to execute commands on several environments.

    Commands can run on the local host (RunCommand*), on another host over
    ssh (the machine argument of RunCommandGeneric), on a ChromeOS device
    through the chroot's remote_access.sh helpers (CrosRunCommand*), or
    inside the ChromeOS SDK chroot through cros_sdk (ChrootRunCommand*).
    """

    def __init__(self, log_level, logger_to_set=None):
        """Create an executer.

        Args:
          log_level: Log level name. "none" disables logging entirely
            (self.logger is None); "quiet" turns off console echo in
            RunCommandGeneric; "verbose" logs commands with LogCmd.
          logger_to_set: Logger to use; defaults to logger.GetLogger().
        """
        self.log_level = log_level
        if log_level == "none":
            self.logger = None
        else:
            if logger_to_set is not None:
                self.logger = logger_to_set
            else:
                self.logger = logger.GetLogger()

    def GetLogLevel(self):
        """Return the current log level name."""
        return self.log_level

    def SetLogLevel(self, log_level):
        """Change the log level name (self.logger is left unchanged)."""
        self.log_level = log_level

    @staticmethod
    def _KillProcessGroup(proc):
        """Send SIGTERM to the process group led by proc.

        Children are started with preexec_fn=os.setsid, which makes the child
        the leader of a new process group whose pgid equals its pid. Using
        proc.pid directly (rather than os.getpgid(proc.pid)) keeps working
        after the child itself has been reaped while descendants are still
        alive. A group that no longer exists is ignored: it is already gone,
        which is what the caller wanted.
        """
        try:
            os.killpg(proc.pid, signal.SIGTERM)
        except ProcessLookupError:
            pass

    @staticmethod
    def _ReadDecoded(pipe, decoder):
        """Read one chunk (up to 16384 bytes) from pipe and decode it.

        Args:
          pipe: File object whose descriptor is read with os.read().
          decoder: Incremental decoder dedicated to this pipe.

        Returns:
          Tuple (text, eof). eof is True when the read returned no bytes; the
          decoder is then flushed, so a truncated character raises instead of
          being dropped. text may be "" without eof when the chunk held only
          the start of a multi-byte character.
        """
        chunk = os.read(pipe.fileno(), 16384)
        eof = not chunk
        return decoder.decode(chunk, final=eof), eof

    @staticmethod
    def _RemoveIfExists(path):
        """Delete path, ignoring the case where it is already gone."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def RunCommandGeneric(
        self,
        cmd,
        return_output=False,
        machine=None,
        username=None,
        command_terminator=None,
        command_timeout=None,
        terminated_timeout=10,
        print_to_console=True,
        env=None,
        except_handler=lambda p, e: None,
    ):
        """Run a command through /bin/bash, locally or over ssh.

        Args:
          cmd: Command; converted with str().
          return_output: If true, collect stdout/stderr and return them.
          machine: If set, the command is wrapped in "ssh -t -t" to this host.
          username: ssh user; only used together with machine.
          command_terminator: Optional CommandTerminator. Once it reports
            termination, the child's process group is sent SIGTERM.
          command_timeout: Seconds after start at which the child's process
            group is sent SIGTERM and output collection stops.
          terminated_timeout: Seconds to keep draining output after the
            child exits (descendants may still hold the pipes open).
          print_to_console: Whether logging also goes to the console; forced
            off when log_level is "quiet".
          env: Environment for the child process.
          except_handler: Called as except_handler(popen_or_None, exc) when
            an exception escapes; the exception is then re-raised.

        Returns:
          Triplet (returncode, stdout, stderr). stdout and stderr are empty
          strings unless return_output is true.
        """

        cmd = str(cmd)

        if self.log_level == "quiet":
            print_to_console = False

        # Guard on self.logger: raising the level to "verbose" with
        # SetLogLevel() on a "none" executer leaves self.logger as None.
        if self.logger:
            if self.log_level == "verbose":
                self.logger.LogCmd(cmd, machine, username, print_to_console)
            else:
                self.logger.LogCmdToFileOnly(cmd, machine, username)
        if command_terminator and command_terminator.IsTerminated():
            if self.logger:
                self.logger.LogError(
                    "Command was terminated!", print_to_console
                )
            return (1, "", "")

        if machine is not None:
            user = ""
            if username is not None:
                user = username + "@"
            cmd = "ssh -t -t %s%s -- '%s'" % (user, machine, cmd)

        # We use setsid so that the child will have a different session id
        # and we can easily kill the process group. This is also important
        # because the child will be disassociated from the parent terminal.
        # In this way the child cannot mess the parent's terminal.
        p = None
        try:
            # pylint: disable=bad-option-value, subprocess-popen-preexec-fn
            p = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                shell=True,
                preexec_fn=os.setsid,
                executable="/bin/bash",
                env=env,
            )

            full_stdout = []
            full_stderr = []
            # One incremental decoder per pipe: a fixed-size os.read() can end
            # in the middle of a multi-byte UTF-8 character, which decoding
            # each chunk on its own would reject.
            stdout_decoder = codecs.getincrementaldecoder("utf8")()
            stderr_decoder = codecs.getincrementaldecoder("utf8")()

            # Pull output from pipes, send it to file/stdout/string
            pipes = [p.stdout, p.stderr]

            my_poll = select.poll()
            my_poll.register(p.stdout, select.POLLIN)
            my_poll.register(p.stderr, select.POLLIN)

            terminated_time = None
            started_time = time.time()

            while pipes:
                if command_terminator and command_terminator.IsTerminated():
                    self._KillProcessGroup(p)
                    if self.logger:
                        self.logger.LogError(
                            "Command received termination request. "
                            "Killed child process group.",
                            print_to_console,
                        )
                    break

                for fd, _ in my_poll.poll(100):
                    if fd == p.stdout.fileno():
                        out_text, eof = self._ReadDecoded(
                            p.stdout, stdout_decoder
                        )
                        if return_output:
                            full_stdout.append(out_text)
                        if self.logger and (out_text or eof):
                            self.logger.LogCommandOutput(
                                out_text, print_to_console
                            )
                        if eof:
                            pipes.remove(p.stdout)
                            my_poll.unregister(p.stdout)
                    if fd == p.stderr.fileno():
                        err_text, eof = self._ReadDecoded(
                            p.stderr, stderr_decoder
                        )
                        if return_output:
                            full_stderr.append(err_text)
                        if self.logger and (err_text or eof):
                            self.logger.LogCommandError(
                                err_text, print_to_console
                            )
                        if eof:
                            pipes.remove(p.stderr)
                            my_poll.unregister(p.stderr)

                if p.poll() is not None:
                    if terminated_time is None:
                        terminated_time = time.time()
                    elif (
                        terminated_timeout is not None
                        and time.time() - terminated_time > terminated_timeout
                    ):
                        if self.logger:
                            self.logger.LogWarning(
                                "Timeout of %s seconds reached since "
                                "process termination." % terminated_timeout,
                                print_to_console,
                            )
                        break

                if (
                    command_timeout is not None
                    and time.time() - started_time > command_timeout
                ):
                    self._KillProcessGroup(p)
                    if self.logger:
                        self.logger.LogWarning(
                            "Timeout of %s seconds reached since process "
                            "started. Killed child process group."
                            % command_timeout,
                            print_to_console,
                        )
                    break

            p.wait()
            if return_output:
                return (
                    p.returncode,
                    "".join(full_stdout),
                    "".join(full_stderr),
                )
            return (p.returncode, "", "")
        except BaseException as err:
            except_handler(p, err)
            raise
        finally:
            # Release the pipe descriptors deterministically instead of
            # waiting for garbage collection.
            if p is not None:
                p.stdout.close()
                p.stderr.close()

    def RunCommand(self, *args, **kwargs):
        """Run a command.

        Takes the same arguments as RunCommandGeneric except for return_output.
        Returns a single value returncode.
        """
        # Make sure that args does not overwrite 'return_output'
        assert len(args) <= 1
        assert "return_output" not in kwargs
        kwargs["return_output"] = False
        return self.RunCommandGeneric(*args, **kwargs)[0]

    def RunCommandWExceptionCleanup(self, *args, **kwargs):
        """Run a command and kill process if exception is thrown.

        Takes the same arguments as RunCommandGeneric except for except_handler.
        Returns same as RunCommandGeneric.
        """

        def KillProc(proc, _):
            if proc:
                self._KillProcessGroup(proc)

        # Make sure that args does not overwrite 'except_handler'
        assert len(args) <= 8
        assert "except_handler" not in kwargs
        kwargs["except_handler"] = KillProc
        return self.RunCommandGeneric(*args, **kwargs)

    def RunCommandWOutput(self, *args, **kwargs):
        """Run a command.

        Takes the same arguments as RunCommandGeneric except for return_output.
        Returns a triplet (returncode, stdout, stderr).
        """
        # Make sure that args does not overwrite 'return_output'
        assert len(args) <= 1
        assert "return_output" not in kwargs
        kwargs["return_output"] = True
        return self.RunCommandGeneric(*args, **kwargs)

    def RemoteAccessInitCommand(self, chromeos_root, machine, port=None):
        """Return shell code that prepares remote access to machine.

        The code sources common.sh and remote_access.sh from
        <chromeos_root>/src/scripts, parses --remote (and --ssh_port when
        port is given; int or str) and calls remote_access_init.
        """
        command = ""
        command += "\nset -- --remote=" + machine
        if port:
            # %s accepts an int port as well as the str from "host:port".
            command += " --ssh_port=%s" % port
        command += "\n. " + chromeos_root + "/src/scripts/common.sh"
        command += "\n. " + chromeos_root + "/src/scripts/remote_access.sh"
        command += "\nTMP=$(mktemp -d)"
        command += '\nFLAGS "$@" || exit 1'
        command += "\nremote_access_init"
        return command

    def WriteToTempShFile(self, contents):
        """Write contents to a new temporary bash script; return its path.

        The file is created with delete=False, so the caller owns removal.
        """
        with tempfile.NamedTemporaryFile(
            "w",
            encoding="utf-8",
            delete=False,
            prefix=os.uname()[1],
            suffix=".sh",
        ) as f:
            f.write("#!/bin/bash\n")
            f.write(contents)
            f.flush()
        return f.name

    def CrosLearnBoard(self, chromeos_root, machine):
        """Return the board of machine as printed by learn_board."""
        command = self.RemoteAccessInitCommand(chromeos_root, machine)
        command += "\nlearn_board"
        command += "\necho ${FLAGS_board}"
        retval, output, _ = self.RunCommandWOutput(command)
        if self.logger:
            self.logger.LogFatalIf(retval, "learn_board command failed")
        elif retval:
            sys.exit(1)
        return output.split()[-1]

    def CrosRunCommandGeneric(
        self,
        cmd,
        return_output=False,
        machine=None,
        command_terminator=None,
        chromeos_root=None,
        command_timeout=None,
        terminated_timeout=10,
        print_to_console=True,
    ):
        """Run a command on a ChromeOS box.

        cmd is written to a local temporary script, copied with CopyFiles to
        the same path on the device, and run there with remote_sh from the
        chroot's remote_access.sh. machine may be "host" or "host:port". The
        local script is removed once the run is over.

        Returns triplet (returncode, stdout, stderr). When return_output is
        true it is returned as a list with the remote_access connection
        banner stripped from stdout.
        """

        if self.log_level != "verbose":
            print_to_console = False

        if self.logger:
            self.logger.LogCmd(cmd, print_to_console=print_to_console)
            self.logger.LogFatalIf(not machine, "No machine provided!")
            self.logger.LogFatalIf(
                not chromeos_root, "chromeos_root not given!"
            )
        else:
            if not chromeos_root or not machine:
                sys.exit(1)
        chromeos_root = os.path.expanduser(chromeos_root)

        port = None
        if ":" in machine:
            machine, port = machine.split(":")
        # Write all commands to a file.
        command_file = self.WriteToTempShFile(cmd)
        try:
            retval = self.CopyFiles(
                command_file,
                command_file,
                dest_machine=machine,
                dest_port=port,
                command_terminator=command_terminator,
                chromeos_root=chromeos_root,
                dest_cros=True,
                recursive=False,
                print_to_console=print_to_console,
            )
            if retval:
                if self.logger:
                    self.logger.LogError(
                        "Could not run remote command on machine."
                        " Is the machine up?"
                    )
                return (retval, "", "")

            command = self.RemoteAccessInitCommand(
                chromeos_root, machine, port
            )
            command += "\nremote_sh bash %s" % command_file
            command += '\nl_retval=$?; echo "$REMOTE_OUT"; exit $l_retval'
            retval = self.RunCommandGeneric(
                command,
                return_output,
                command_terminator=command_terminator,
                command_timeout=command_timeout,
                terminated_timeout=terminated_timeout,
                print_to_console=print_to_console,
            )
        finally:
            # The local script only serves as the rsync source; remove it
            # after the remote run so repeated calls do not leak temp files.
            self._RemoveIfExists(command_file)
        if return_output:
            connect_signature = (
                "Initiating first contact with remote host\n"
                + "Connection OK\n"
            )
            connect_signature_re = re.compile(connect_signature)
            modded_retval = list(retval)
            modded_retval[1] = connect_signature_re.sub("", retval[1])
            return modded_retval
        return retval

    def CrosRunCommand(self, *args, **kwargs):
        """Run a command on a ChromeOS box.

        Takes the same arguments as CrosRunCommandGeneric except for return_output.
        Returns a single value returncode.
        """
        # Make sure that args does not overwrite 'return_output'
        assert len(args) <= 1
        assert "return_output" not in kwargs
        kwargs["return_output"] = False
        return self.CrosRunCommandGeneric(*args, **kwargs)[0]

    def CrosRunCommandWOutput(self, *args, **kwargs):
        """Run a command on a ChromeOS box.

        Takes the same arguments as CrosRunCommandGeneric except for return_output.
        Returns a triplet (returncode, stdout, stderr).
        """
        # Make sure that args does not overwrite 'return_output'
        assert len(args) <= 1
        assert "return_output" not in kwargs
        kwargs["return_output"] = True
        return self.CrosRunCommandGeneric(*args, **kwargs)

    def ChrootRunCommandGeneric(
        self,
        chromeos_root,
        command,
        return_output=False,
        command_terminator=None,
        command_timeout=None,
        terminated_timeout=10,
        print_to_console=True,
        cros_sdk_options="",
        env=None,
    ):
        """Runs a command within the chroot.

        command is written to a temporary script under
        <chromeos_root>/src/scripts and run through cros_sdk. The script is
        removed on every exit path.

        Returns triplet (returncode, stdout, stderr).
        """

        if self.log_level != "verbose":
            print_to_console = False

        if self.logger:
            self.logger.LogCmd(command, print_to_console=print_to_console)

        with tempfile.NamedTemporaryFile(
            "w",
            encoding="utf-8",
            delete=False,
            dir=os.path.join(chromeos_root, "src/scripts"),
            suffix=".sh",
            prefix="in_chroot_cmd",
        ) as f:
            f.write("#!/bin/bash\n")
            f.write(command)
            f.write("\n")
            f.flush()

        command_file = f.name
        try:
            os.chmod(command_file, 0o777)

            # if return_output is set, run a test command first to make sure
            # that the chroot already exists. We want the final returned
            # output to skip the output from chroot creation steps.
            if return_output:
                ret = self.RunCommand(
                    "cd %s; cros_sdk %s -- true"
                    % (chromeos_root, cros_sdk_options),
                    env=env,
                    # Give this command a long time to execute; it might
                    # involve setting the chroot up, or running fstrim on its
                    # image file. Both of these operations can take well over
                    # the timeout default of 10 seconds.
                    terminated_timeout=5 * 60,
                )
                if ret:
                    return (ret, "", "")

            # Run command_file inside the chroot, making sure that any "~" is
            # expanded by the shell inside the chroot, not outside.
            command = "cd %s; cros_sdk %s -- bash -c '%s/%s'" % (
                chromeos_root,
                cros_sdk_options,
                CHROMEOS_SCRIPTS_DIR,
                os.path.basename(command_file),
            )
            return self.RunCommandGeneric(
                command,
                return_output,
                command_terminator=command_terminator,
                command_timeout=command_timeout,
                terminated_timeout=terminated_timeout,
                print_to_console=print_to_console,
                env=env,
            )
        finally:
            # Also covers a failed chroot check and escaping exceptions.
            self._RemoveIfExists(command_file)

    def ChrootRunCommand(self, *args, **kwargs):
        """Runs a command within the chroot.

        Takes the same arguments as ChrootRunCommandGeneric except for
        return_output.
        Returns a single value returncode.
        """
        # Make sure that args does not overwrite 'return_output'
        assert len(args) <= 2
        assert "return_output" not in kwargs
        kwargs["return_output"] = False
        return self.ChrootRunCommandGeneric(*args, **kwargs)[0]

    def ChrootRunCommandWOutput(self, *args, **kwargs):
        """Runs a command within the chroot.

        Takes the same arguments as ChrootRunCommandGeneric except for
        return_output.
        Returns a triplet (returncode, stdout, stderr).
        """
        # Make sure that args does not overwrite 'return_output'
        assert len(args) <= 2
        assert "return_output" not in kwargs
        kwargs["return_output"] = True
        return self.ChrootRunCommandGeneric(*args, **kwargs)

    def RunCommands(
        self, cmdlist, machine=None, username=None, command_terminator=None
    ):
        """Join cmdlist with " ;\\n" and run it as one command; return its code."""
        cmd = " ;\n".join(cmdlist)
        return self.RunCommand(
            cmd,
            machine=machine,
            username=username,
            command_terminator=command_terminator,
        )

    def CopyFiles(
        self,
        src,
        dest,
        src_machine=None,
        src_port=None,
        dest_machine=None,
        dest_port=None,
        src_user=None,
        dest_user=None,
        recursive=True,
        command_terminator=None,
        chromeos_root=None,
        src_cros=False,
        dest_cros=False,
        print_to_console=True,
    ):
        """Copy src to dest with rsync.

        When src_cros or dest_cros is set, exactly one side is a ChromeOS
        device reached through the chroot's remote_access.sh helpers (which
        need chromeos_root); rsync then runs on the other side's machine.
        Otherwise "rsync -a" runs on dest_machine (locally when it is None).

        Args:
          src, dest: Paths; "~" is expanded locally. With recursive=True a
            trailing "/" is appended to both.
          src_machine, dest_machine: Hosts of each side (None = local).
          src_port, dest_port: ssh port of the ChromeOS side, if any.
          src_user, dest_user: Users for the non-ChromeOS hosts.
          recursive: See src/dest above.
          command_terminator: Passed through to RunCommand.
          chromeos_root: ChromeOS checkout; required with src_cros/dest_cros.
          src_cros, dest_cros: Which side, if any, is the ChromeOS device.
          print_to_console: Passed through to RunCommand.

        Returns:
          The return code of the rsync command.
        """
        src = os.path.expanduser(src)
        dest = os.path.expanduser(dest)

        if recursive:
            src = src + "/"
            dest = dest + "/"

        if src_cros or dest_cros:
            if self.logger:
                self.logger.LogFatalIf(
                    src_cros == dest_cros,
                    "Only one of src_cros and dest_cros can be True.",
                )
                self.logger.LogFatalIf(
                    not chromeos_root, "chromeos_root not given!"
                )
            elif src_cros == dest_cros or not chromeos_root:
                sys.exit(1)
            if src_cros:
                cros_machine = src_machine
                cros_port = src_port
                host_machine = dest_machine
                host_user = dest_user
            else:
                cros_machine = dest_machine
                cros_port = dest_port
                host_machine = src_machine
                host_user = src_user

            command = self.RemoteAccessInitCommand(
                chromeos_root, cros_machine, cros_port
            )
            ssh_command = (
                "ssh -o StrictHostKeyChecking=no"
                + " -o UserKnownHostsFile=$(mktemp)"
                + " -i $TMP_PRIVATE_KEY"
            )
            if cros_port:
                ssh_command += " -p %s" % cros_port
            rsync_prefix = '\nrsync -r -e "%s" ' % ssh_command
            if dest_cros:
                command += rsync_prefix + "%s root@%s:%s" % (
                    src,
                    cros_machine,
                    dest,
                )
            else:
                command += rsync_prefix + "root@%s:%s %s" % (
                    cros_machine,
                    src,
                    dest,
                )

            return self.RunCommand(
                command,
                machine=host_machine,
                username=host_user,
                command_terminator=command_terminator,
                print_to_console=print_to_console,
            )

        if dest_machine == src_machine:
            command = "rsync -a %s %s" % (src, dest)
        else:
            if src_machine is None:
                src_machine = os.uname()[1]
                src_user = getpass.getuser()
            # Only add "user@" when a user is known; formatting None would
            # produce a literal "None@host".
            if src_user is None:
                src_host = src_machine
            else:
                src_host = "%s@%s" % (src_user, src_machine)
            command = "rsync -a %s:%s %s" % (src_host, src, dest)
        return self.RunCommand(
            command,
            machine=dest_machine,
            username=dest_user,
            command_terminator=command_terminator,
            print_to_console=print_to_console,
        )

    def RunCommand2(
        self,
        cmd,
        cwd=None,
        line_consumer=None,
        timeout=None,
        shell=True,
        join_stderr=True,
        env=None,
        except_handler=lambda p, e: None,
    ):
        """Run the command with an extra feature line_consumer.

        This version allows developers to provide a line_consumer which will
        be fed execution output lines.

        A line_consumer is a callback, which is given a chance to run for each
        line the execution outputs (either to stdout or stderr). It is called
        with exactly these keyword arguments -
          line    -  The line output by the binary, including its trailing
                     newline (the last line of a stream may lack one).
          output  -  'stdout' or 'stderr'. When join_stderr is True, this
                     value will always be 'stdout'.
          pobject -  The object used to control execution, for example, call
                     pobject.kill().

        Note: As this is written, the stdin for the process executed is
        not associated with the stdin of the caller of this routine.

        Args:
          cmd: Command in a single string.
          cwd: Working directory for execution.
          line_consumer: A function that will be called by this function. See
            above for details.
          timeout: Once exceeded, SIGTERM is sent to the command's process
            group on each poll iteration until its output streams close.
          shell: Whether to use a shell for execution.
          join_stderr: Whether join stderr to stdout stream.
          env: Execution environment.
          except_handler: Callback for when exception is thrown during command
            execution. Passed process object and exception.

        Returns:
          Execution return code.

        Raises:
          Whatever subprocess.Popen raises if the command process cannot be
          started (missing permission, no such file, etc).
        """

        class StreamHandler(object):
            """Buffers one output stream and feeds whole lines to a consumer."""

            def __init__(self, pobject, fd, name, line_consumer, encoding):
                self._pobject = pobject
                self._fd = fd
                self._name = name
                self._buf = ""
                self._line_consumer = line_consumer
                # os.read() returns bytes even though the Popen object is in
                # text mode, so decode with the stream's own encoding. The
                # incremental decoder copes with a read that ends inside a
                # multi-byte character.
                self._decoder = codecs.getincrementaldecoder(encoding)()

            def read_and_notify_line(self):
                """Read available bytes; return False if the stream hit EOF."""
                data = os.read(self._fd, 1024)
                self._buf += self._decoder.decode(data)
                self.notify_line()
                return bool(data)

            def notify_line(self):
                """Feed every complete line in the buffer to the consumer."""
                newline = self._buf.find("\n")
                while newline >= 0:
                    line = self._buf[: newline + 1]
                    self._buf = self._buf[newline + 1 :]
                    self._line_consumer(
                        line=line, output=self._name, pobject=self._pobject
                    )
                    newline = self._buf.find("\n")

            def notify_eos(self):
                """Flush the decoder; emit a last line lacking a newline."""
                self._buf += self._decoder.decode(b"", final=True)
                if self._buf != "":
                    self._line_consumer(
                        line=self._buf, output=self._name, pobject=self._pobject
                    )
                    self._buf = ""

        # Guard on self.logger for the same reason as in RunCommandGeneric.
        if self.logger:
            if self.log_level == "verbose":
                self.logger.LogCmd(cmd)
            else:
                self.logger.LogCmdToFileOnly(cmd)

        # We use setsid so that the child will have a different session id
        # and we can easily kill the process group. This is also important
        # because the child will be disassociated from the parent terminal.
        # In this way the child cannot mess the parent's terminal.
        pobject = None
        try:
            # pylint: disable=bad-option-value, subprocess-popen-preexec-fn
            pobject = subprocess.Popen(
                cmd,
                cwd=cwd,
                bufsize=1024,
                env=env,
                shell=shell,
                universal_newlines=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT if join_stderr else subprocess.PIPE,
                preexec_fn=os.setsid,
            )

            # We provide a default line_consumer
            if line_consumer is None:
                line_consumer = lambda **d: None
            # Text-mode pipes expose the encoding Popen chose for them.
            encoding = pobject.stdout.encoding
            start_time = time.time()
            poll = select.poll()
            outfd = pobject.stdout.fileno()
            poll.register(outfd, select.POLLIN | select.POLLPRI)
            handlermap = {
                outfd: StreamHandler(
                    pobject, outfd, "stdout", line_consumer, encoding
                )
            }
            if not join_stderr:
                errfd = pobject.stderr.fileno()
                poll.register(errfd, select.POLLIN | select.POLLPRI)
                handlermap[errfd] = StreamHandler(
                    pobject, errfd, "stderr", line_consumer, encoding
                )
            while handlermap:
                for fd, evt in poll.poll(300):
                    handler = handlermap[fd]
                    if evt & (select.POLLPRI | select.POLLIN):
                        # A readable event whose read returns no bytes is EOF
                        # (some platforms report EOF as POLLIN, not POLLHUP).
                        at_eos = not handler.read_and_notify_line()
                    else:
                        at_eos = bool(
                            evt
                            & (select.POLLHUP | select.POLLERR | select.POLLNVAL)
                        )
                    if at_eos:
                        handler.notify_eos()
                        poll.unregister(fd)
                        del handlermap[fd]

                if timeout is not None and (time.time() - start_time > timeout):
                    self._KillProcessGroup(pobject)

            return pobject.wait()
        except BaseException as err:
            except_handler(pobject, err)
            raise
        finally:
            if pobject is not None:
                pobject.stdout.close()
                if pobject.stderr is not None:
                    pobject.stderr.close()
743
744
class MockCommandExecuter(CommandExecuter):
    """Mock class for class CommandExecuter.

    Nothing is executed: each command is only logged, prefixed with
    "(Mock) ", through logger.GetLogger(), and reported as successful.
    """

    def RunCommandGeneric(
        self,
        cmd,
        return_output=False,
        machine=None,
        username=None,
        command_terminator=None,
        command_timeout=None,
        terminated_timeout=10,
        print_to_console=True,
        env=None,
        except_handler=lambda p, e: None,
    ):
        """Log cmd as a mock invocation and return (0, "", "").

        command_timeout must be falsy. The other execution arguments are
        accepted for interface compatibility and ignored.
        """
        assert not command_timeout
        target_machine = "localhost" if machine is None else machine
        target_user = "current" if username is None else username
        logger.GetLogger().LogCmd(
            "(Mock) " + str(cmd), target_machine, target_user, print_to_console
        )
        return (0, "", "")

    def RunCommand(self, *args, **kwargs):
        """Mock RunCommand: return only the return code (always 0)."""
        assert "return_output" not in kwargs
        result = self.RunCommandGeneric(*args, return_output=False, **kwargs)
        return result[0]

    def RunCommandWOutput(self, *args, **kwargs):
        """Mock RunCommandWOutput: return the (0, "", "") triplet."""
        assert "return_output" not in kwargs
        return self.RunCommandGeneric(*args, return_output=True, **kwargs)
781
782
class CommandTerminator(object):
    """Object to request termination of a command in execution.

    Running commands poll IsTerminated(); any holder of the object may call
    Terminate() to raise the flag. The flag is never cleared.
    """

    def __init__(self):
        # Starts lowered; Terminate() is the only method that sets it.
        self.terminated = False

    def Terminate(self):
        """Raise the termination flag."""
        self.terminated = True

    def IsTerminated(self):
        """Return the current value of the termination flag."""
        return self.terminated
794