xref: /aosp_15_r20/tools/asuite/atest/integration_tests/split_build_test_script.py (revision c2e18aaa1096c836b086f94603d04f4eb9cf37f5)
1#!/usr/bin/env python3
2#
3# Copyright 2023, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Module to facilitate integration test within the build and test environment.
18
19This module provides utilities for running tests in both build and test
20environments, managing environment variables, and snapshotting the workspace for
21restoration later.
22"""
23
24import argparse
25import atexit
26import concurrent.futures
27import copy
28import dataclasses
29import datetime
30import functools
31import itertools
32import logging
33import multiprocessing
34import os
35import pathlib
36import shutil
37import subprocess
38import sys
39import tarfile
40import tempfile
41import time
42import traceback
43from typing import Any, Callable, Iterator
44import unittest
45import zipfile
46
47from snapshot import Snapshot
48
# Env var that, when set, overrides the path of the snapshot storage tar file
# (otherwise a default location under the system temp dir is used).
SNAPSHOT_STORAGE_TAR_KEY = 'SNAPSHOT_STORAGE_TAR_PATH'

# Env var holding the repo root directory used by build and test steps.
ANDROID_BUILD_TOP_KEY = 'ANDROID_BUILD_TOP'
54
55
class IntegrationTestConfiguration:
  """Internal class to store integration test configuration.

  All fields are plain class-level defaults that callers override per
  instance.
  """

  # Serial of the target device; None when no serial was given.
  device_serial: str = None
  # Whether build steps should run in this invocation.
  is_build_env: bool = False
  # Whether test steps should run in this invocation.
  is_test_env: bool = False
  # Directory holding the snapshots taken after build steps.
  snapshot_storage_path: pathlib.Path = None
  # Tar file that packs the snapshot storage when is_tar_snapshot is set.
  snapshot_storage_tar_path: pathlib.Path = None
  # Directory into which snapshots are restored before test steps.
  workspace_path: pathlib.Path = None
  # Whether the snapshot storage is packed into / unpacked from a tar file.
  is_tar_snapshot: bool = False
66
67
class StepInput:
  """The argument handed to every build or test step function.

  Bundles the environment variables, the repo root, the integration test
  configuration and the objects restored from the previous snapshot.
  """

  def __init__(self, env, repo_root, config, objs):
    self._env, self._repo_root = env, repo_root
    self._config, self._objs = config, objs

  def get_device_serial_args_or_empty(self) -> str:
    """Returns ' -s <serial>' for command lines, or '' when no serial is set.

    Raises:
        RuntimeError: When no serial is set while running in test env without
          ANDROID_BUILD_TOP in os.environ.
    """
    # TODO: b/336839543 - Remove this method when we deprecate the support to
    # run the integration test directly through 'python **.py' command.
    serial = self._config.device_serial
    if serial:
      return ' -s ' + serial
    # A test env without ANDROID_BUILD_TOP is most likely a test lab, where
    # connected devices can be allocated to other tests, so every atest call
    # must name its device explicitly.
    if self._config.is_test_env and ANDROID_BUILD_TOP_KEY not in os.environ:
      raise RuntimeError('Device serial is required but not set')
    # In a local run an empty value lets tradefed pick the device.
    return ''

  def get_device_serial(self) -> str:
    """Returns the configured device serial.

    Raises:
        RuntimeError: When no device serial is configured.
    """
    serial = self._config.device_serial
    if serial:
      return serial
    raise RuntimeError('Device serial is not set')

  def get_env(self):
    """Returns the environment variables given to the step."""
    return self._env

  def get_repo_root(self) -> str:
    """Returns the repo root directory."""
    return self._repo_root

  def get_obj(self, name: str) -> Any:
    """Returns the object saved under `name` in the snapshot, or None."""
    return self._objs.get(name)

  def get_config(self) -> IntegrationTestConfiguration:
    """Returns the integration test configuration."""
    return self._config
113
114
class StepOutput:
  """Output information generated from a build step.

  Collects what the following snapshot should contain: paths to include and
  exclude, environment variable keys to record, and named objects to save.
  """

  def __init__(self):
    self._snapshot_include_paths: list[str] = []
    self._snapshot_exclude_paths: list[str] = []
    self._snapshot_env_keys: list[str] = []
    self._snapshot_objs: dict[str, Any] = {}

  def add_snapshot_include_paths(self, paths: list[str]) -> None:
    """Add paths to include in snapshot artifacts."""
    self._snapshot_include_paths.extend(paths)

  def set_snapshot_include_paths(self, paths: list[str]) -> None:
    """Set the snapshot include paths.

    Note that the default include paths will be removed.
    Use add_snapshot_include_paths if that's not intended.

    Args:
        paths: The new list of paths to include for snapshot.
    """
    # Copy before replacing: `paths` may be the very list returned by
    # get_snapshot_include_paths(), which a clear() would empty before it is
    # read. Slice assignment keeps the list object so earlier references to it
    # stay in sync.
    self._snapshot_include_paths[:] = list(paths)

  def add_snapshot_exclude_paths(self, paths: list[str]) -> None:
    """Add paths to exclude from snapshot artifacts."""
    self._snapshot_exclude_paths.extend(paths)

  def add_snapshot_env_keys(self, keys: list[str]) -> None:
    """Add environment variable keys for snapshot."""
    self._snapshot_env_keys.extend(keys)

  def add_snapshot_obj(self, name: str, obj: Any) -> None:
    """Add an object to save in the snapshot under `name`."""
    self._snapshot_objs[name] = obj

  def get_snapshot_include_paths(self) -> list[str]:
    """Returns the stored snapshot include path list."""
    return self._snapshot_include_paths

  def get_snapshot_exclude_paths(self) -> list[str]:
    """Returns the stored snapshot exclude path list."""
    return self._snapshot_exclude_paths

  def get_snapshot_env_keys(self) -> list[str]:
    """Returns the stored snapshot env key list."""
    return self._snapshot_env_keys

  def get_snapshot_objs(self) -> dict[str, Any]:
    """Returns the stored snapshot object dictionary."""
    return self._snapshot_objs
167
168
class SplitBuildTestScript:
  """Utility for running integration test in build and test environment.

  Steps are registered with add_build_step / add_test_step and must alternate,
  starting with a build step (build, test, build, test, ...). In build env each
  build step runs and its output is saved to a snapshot; in test env each test
  step first restores the snapshot taken after the build step paired with it.
  """

  def __init__(self, name: str, config: IntegrationTestConfiguration) -> None:
    """Initializes the script.

    Args:
        name: Identifier of the script; used as the prefix of snapshot names.
        config: The integration test configuration.
    """
    self._config = config
    self._id: str = name
    self._snapshot: Snapshot = Snapshot(self._config.snapshot_storage_path)
    self._has_already_run: bool = False
    self._steps: list['SplitBuildTestScript._Step'] = []
    self._snapshot_restore_exclude_paths: list[str] = []

  def get_config(self) -> IntegrationTestConfiguration:
    """Returns the integration test configuration."""
    return self._config

  def add_build_step(self, step_func: Callable[[StepInput], StepOutput]):
    """Add a build step.

    Args:
        step_func: A function that takes a StepInput object and returns a
          StepOutput object.

    Raises:
        RuntimeError: Unexpected step orders detected.
    """
    if self._steps and isinstance(self._steps[-1], self._BuildStep):
      raise RuntimeError(
          'Two adjacent build steps are unnecessary. Combine them.'
      )
    self._steps.append(self._BuildStep(step_func))

  def add_test_step(self, step_func: Callable[[StepInput], None]):
    """Add a test step.

    Args:
        step_func: A function that takes a StepInput object.

    Raises:
        RuntimeError: Unexpected step orders detected.
    """
    if not self._steps or isinstance(self._steps[-1], self._TestStep):
      raise RuntimeError('A build step is required before a test step.')
    self._steps.append(self._TestStep(step_func))

  def _exception_to_dict(self, exception: Exception) -> dict[str, str]:
    """Converts an exception object to a dictionary to be saved by json."""
    return {
        'type': exception.__class__.__name__,
        'message': str(exception),
        'traceback': ''.join(traceback.format_tb(exception.__traceback__)),
    }

  def _dict_to_exception(self, exception_dict: dict[str, str]) -> RuntimeError:
    """Converts a dictionary made by _exception_to_dict to an exception."""
    return RuntimeError(
        'The last build step raised an exception:\n'
        f'{exception_dict["type"]}: {exception_dict["message"]}\n'
        'Traceback (from saved snapshot):\n'
        f'{exception_dict["traceback"]}'
    )

  def _snapshot_name(self, index: int) -> str:
    """Returns the snapshot name shared by the build/test pair of a step."""
    # Steps strictly alternate build/test, so index // 2 numbers the pairs.
    return self._id + '_' + str(index // 2)

  def run(self):
    """Run the steps added previously.

    Build steps run only in build env and test steps only in test env. When a
    build step raises, the exception is recorded in its snapshot (so the paired
    test step can report it) and then re-raised.

    This function cannot be executed more than once.

    Raises:
        RuntimeError: When attempted to run the script multiple times, or, in
          test env, when the paired build step had raised.
    """
    if self._has_already_run:
      raise RuntimeError(f'Script {self._id} has already run.')
    self._has_already_run = True

    build_step_exception_key = '_internal_build_step_exception'

    for index, step in enumerate(self._steps):
      if isinstance(step, self._BuildStep) and self.get_config().is_build_env:
        env = os.environ
        step_in = StepInput(
            env,
            self._get_repo_root(os.environ),
            self.get_config(),
            {},
        )
        last_exception = None
        try:
          step_out = step.get_step_func()(step_in)
        # pylint: disable-next=broad-exception-caught
        except Exception as e:
          last_exception = e
          step_out = StepOutput()
          step_out.add_snapshot_obj(
              build_step_exception_key, self._exception_to_dict(e)
          )

        # The snapshot is taken even on failure so the test env can surface
        # the build step's exception.
        self._take_snapshot(
            self._get_repo_root(os.environ),
            self._snapshot_name(index),
            step_out,
            env,
        )

        if last_exception is not None:
          raise last_exception

      if isinstance(step, self._TestStep) and self.get_config().is_test_env:
        env, objs = self._restore_snapshot(self._snapshot_name(index))

        if build_step_exception_key in objs:
          raise self._dict_to_exception(objs[build_step_exception_key])

        step_in = StepInput(
            env,
            self._get_repo_root(env),
            self.get_config(),
            objs,
        )
        step.get_step_func()(step_in)

  def add_snapshot_restore_exclude_paths(self, paths: list[str]) -> None:
    """Add paths to ignore during snapshot directory restore."""
    self._snapshot_restore_exclude_paths.extend(paths)

  def _take_snapshot(
      self,
      repo_root: str,
      name: str,
      step_out: StepOutput,
      env: dict[str, str],
  ) -> None:
    """Take a snapshot of the repository and environment."""
    self._snapshot.take_snapshot(
        name,
        repo_root,
        include_paths=step_out.get_snapshot_include_paths(),
        exclude_paths=step_out.get_snapshot_exclude_paths(),
        env_keys=step_out.get_snapshot_env_keys(),
        env=env,
        objs=step_out.get_snapshot_objs(),
    )

  def _restore_snapshot(self, name: str) -> tuple[Any, Any]:
    """Restore the repository and environment from a snapshot.

    Returns:
        The (env, objs) pair returned by Snapshot.restore_snapshot; run()
        uses env as the step environment and objs as the saved objects.
    """
    return self._snapshot.restore_snapshot(
        name,
        self.get_config().workspace_path.as_posix(),
        exclude_paths=self._snapshot_restore_exclude_paths,
    )

  def _get_repo_root(self, env) -> str:
    """Get repo root directory.

    In build env the value always comes from the live os.environ; otherwise
    it is read from the given env (e.g. one restored from a snapshot).
    """
    if self.get_config().is_build_env:
      return os.environ[ANDROID_BUILD_TOP_KEY]
    return env[ANDROID_BUILD_TOP_KEY]

  class _Step:
    """Parent class to build step and test step for typing declaration."""

  class _BuildStep(_Step):
    """A registered build step wrapping its step function."""

    def __init__(self, step_func: Callable[[StepInput], StepOutput]):
      self._step_func = step_func

    def get_step_func(self) -> Callable[[StepInput], StepOutput]:
      """Returns the stored step function for build."""
      return self._step_func

  class _TestStep(_Step):
    """A registered test step wrapping its step function."""

    def __init__(self, step_func: Callable[[StepInput], None]):
      self._step_func = step_func

    def get_step_func(self) -> Callable[[StepInput], None]:
      """Returns the stored step function for test."""
      return self._step_func
342
343
class SplitBuildTestTestCase(unittest.TestCase):
  """Base test case class for split build-test scripting tests."""

  # Internal config to be injected to the test case from main. It is stored on
  # the class by set_config, so all instances of the class share it.
  _config: IntegrationTestConfiguration = None

  @classmethod
  def set_config(cls, config: IntegrationTestConfiguration) -> None:
    """Sets the configuration shared by the test instances of this class."""
    cls._config = config

  @classmethod
  def get_config(cls) -> IntegrationTestConfiguration:
    """Returns the configuration injected with set_config (None if unset)."""
    return cls._config

  def create_split_build_test_script(
      self, name: str = None
  ) -> SplitBuildTestScript:
    """Return an instance of SplitBuildTestScript with the given name.

    Args:
        name: The name of the script. The name will be used to store snapshots
          and it's recommended to set the name to test id such as self.id().
          Defaults to the test id if not set.

    Returns:
        A SplitBuildTestScript using this class's configuration.
    """
    if not name:
      name = self.id()
      main_module_name = '__main__'
      if name.startswith(main_module_name):
        # Replace only the leading module name with the script's file stem so
        # snapshot names are stable across invocation styles. Modules run
        # without a file (e.g. `python -c`) keep the id unchanged.
        main_module = sys.modules.get(main_module_name)
        main_file = getattr(main_module, '__file__', None)
        if main_file:
          script_name = pathlib.Path(main_file).stem
          name = script_name + name[len(main_module_name):]
    return SplitBuildTestScript(name, self.get_config())
375
376
class _FileCompressor:
  """Class for compressing and decompressing files.

  Each regular file under a directory is replaced in place by a single-entry
  zip archive named '<file name>.zip', and can later be restored from it.
  """

  def compress_all_sub_files(self, root_path: pathlib.Path) -> None:
    """Compresses all files in the given directory and subdirectories.

    Args:
        root_path: The path to the root directory.

    Raises:
        Exception: The first error raised while compressing any file.
    """
    # Collect the file list before any worker starts: workers add and delete
    # files in the same tree, which a still-running rglob would otherwise see
    # (e.g. re-compressing a freshly written archive).
    files = [path for path in root_path.rglob('*') if path.is_file()]
    self._run_concurrently(self.compress_file, files)

  def compress_file(self, file_path: pathlib.Path) -> None:
    """Compresses a single file to '<file name>.zip' and deletes the original.

    '.zip' is appended to the full name rather than replacing the suffix, so
    files such as 'a.txt' and 'a.json' in one directory get distinct archives
    and an existing 'x.zip' is never overwritten by its own archive.
    NOTE(review): a file 'x' next to a file 'x.zip' would still collide.

    Args:
        file_path: The path to the file to compress.
    """
    zip_path = file_path.with_name(file_path.name + '.zip')
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zip_file:
      zip_file.write(file_path, arcname=file_path.name)
    file_path.unlink()

  def decompress_all_sub_files(self, root_path: pathlib.Path) -> None:
    """Decompresses all compressed sub files in the given directory.

    Args:
        root_path: The path to the root directory.

    Raises:
        Exception: The first error raised while decompressing any archive
          (e.g. zipfile.BadZipFile).
    """
    # Collect the archives first so extracted files that themselves end in
    # '.zip' are not picked up and extracted again.
    archives = list(root_path.rglob('*.zip'))
    self._run_concurrently(self.decompress_file, archives)

  def decompress_file(self, file_path: pathlib.Path) -> None:
    """Decompresses a single zip file next to itself and deletes the archive.

    Args:
        file_path: The path to the compressed file.
    """
    with zipfile.ZipFile(file_path, 'r') as zip_file:
      zip_file.extractall(file_path.parent)
    file_path.unlink()

  @staticmethod
  def _run_concurrently(
      func: Callable[[pathlib.Path], None], paths: list[pathlib.Path]
  ) -> None:
    """Applies func to every path on a thread pool and re-raises failures."""
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=multiprocessing.cpu_count()
    ) as executor:
      futures = [executor.submit(func, path) for path in paths]
    # Leaving the `with` block waits for all tasks; result() re-raises the
    # first worker exception instead of silently discarding it.
    for future in futures:
      future.result()
428
429
class ParallelTestRunner(unittest.TextTestRunner):
  """A class that holds the logic of parallel test execution.

  Test methods wrapped by decorators defined in this class will be pre-executed
  at the beginning of the test run in parallel and have the results cached when
  the test runner is also this class. Available decorators: `run_in_parallel`
  for running a test method in parallel during both build and test env,
  `run_in_parallel_in_build_env` for parallel run in build env only, and
  `run_in_parallel_in_test_env` for parallel run in test env only.

  Methods decorated with `setup_parallel`, `setup_parallel_in_build_env` or
  `setup_parallel_in_test_env` are called with no arguments before the
  parallel tests of their test class are pre-executed.
  """

  _RUN_IN_PARALLEL = 'run_in_parallel'
  _RUN_IN_PARALLEL_IN_BUILD_ENV = 'run_in_parallel_in_build_env'
  _RUN_IN_PARALLEL_IN_TEST_ENV = 'run_in_parallel_in_test_env'
  _DECORATOR_NAME = 'decorator_name'

  @classmethod
  def _cache_first(
      cls, func: Callable[..., Any], decorator_name: str
  ) -> Callable[..., Any]:
    """Cache a function's first call result and consumes it in the next call.

    This decorator is similar to the built-in `functools.cache` decorator except
    that this decorator caches the first call's run result and emits it in the
    next run of the function, regardless of the function's input argument value
    changes. Caching only the first call of the test ensures test retries emit
    fresh results.

    More precisely: calling the wrapper with `only_set_next_run_caching=True`
    arms the cache. The next call runs `func` and stores its outcome (return
    value or raised exception). The call after that replays the stored outcome
    without running `func`, and every later call runs `func` fresh.

    Args:
        func: The function to cache.
        decorator_name: The name of the decorator.

    Returns:
        The wrapped function, tagged with the `decorator_name` attribute so
        the runner can recognize it.
    """
    # Tag the function; functools.wraps copies __dict__, so the wrapper gets
    # the same attribute.
    setattr(func, cls._DECORATOR_NAME, decorator_name)

    class _ResultCache:
      result = None
      is_to_be_cached = False

    result_cache = _ResultCache()

    @functools.wraps(func)
    def _wrapped(*args, only_set_next_run_caching=False, **kwargs):
      if only_set_next_run_caching:
        result_cache.is_to_be_cached = True
        return

      def _get_fresh_call_result():
        try:
          return (func(*args, **kwargs), None)
        # pylint: disable-next=broad-exception-caught
        except Exception as e:
          return (None, e)

      if result_cache.is_to_be_cached:
        result = _get_fresh_call_result()
        result_cache.result = result
        result_cache.is_to_be_cached = False
      elif result_cache.result is not None:
        result = result_cache.result
        result_cache.result = None
      else:
        result = _get_fresh_call_result()
      if result[1] is not None:
        raise result[1]
      return result[0]

    return _wrapped

  @classmethod
  def run_in_parallel(cls, func: Callable[..., Any]) -> Callable[..., Any]:
    """Hint that a test method can run in parallel."""
    return cls._cache_first(func, cls.run_in_parallel.__name__)

  @classmethod
  def run_in_parallel_in_build_env(
      cls, func: Callable[..., Any]
  ) -> Callable[..., Any]:
    """Hint that a test method can run in parallel in build env only."""
    return cls._cache_first(func, cls.run_in_parallel_in_build_env.__name__)

  @classmethod
  def run_in_parallel_in_test_env(
      cls, func: Callable[..., Any]
  ) -> Callable[..., Any]:
    """Hint that a test method can run in parallel in test env only."""
    return cls._cache_first(func, cls.run_in_parallel_in_test_env.__name__)

  @classmethod
  def setup_parallel(cls, func: Callable[..., Any]) -> Callable[..., Any]:
    """Hint that a method is for setting up a parallel run."""
    return cls._cache_first(func, cls.setup_parallel.__name__)

  @classmethod
  def setup_parallel_in_build_env(
      cls, func: Callable[..., Any]
  ) -> Callable[..., Any]:
    """Hint that a method is for setting up a parallel run in build env only."""
    return cls._cache_first(func, cls.setup_parallel_in_build_env.__name__)

  @classmethod
  def setup_parallel_in_test_env(
      cls, func: Callable[..., Any]
  ) -> Callable[..., Any]:
    """Hint that a method is for setting up a parallel run in test env only."""
    return cls._cache_first(func, cls.setup_parallel_in_test_env.__name__)

  def run(self, test):
    """Executes parallel tests first and then runs the whole suite normally."""
    for test_suite in test:
      # Per-class suites are inspected; a bare TestCase is not iterable.
      if isinstance(test_suite, unittest.TestSuite):
        self._pre_execute_parallel_tests(test_suite)
    return super().run(test)

  @staticmethod
  def _get_test_function(test: unittest.TestCase) -> Callable[..., Any]:
    """Gets the test function from a TestCase class wrapped by unittest."""
    return getattr(test, test.id().split('.')[-1])

  @classmethod
  def _get_parallel_setups(
      cls, test_suite: unittest.TestSuite
  ) -> set[Callable[..., Any]]:
    """Returns a set of functions to be executed as setup for parallel run."""
    test_cls = None
    for test_case in test_suite:
      test_cls = test_case.__class__
      break
    if not test_cls:
      return set()

    result = set()
    update_result = lambda decorator: result.update(
        filter(
            lambda func: callable(func)
            and decorator.__name__ == getattr(func, cls._DECORATOR_NAME, None),
            map(functools.partial(getattr, test_cls), dir(test_cls)),
        )
    )
    update_result(cls.setup_parallel)
    if test_cls.get_config().is_build_env:
      update_result(cls.setup_parallel_in_build_env)
    if test_cls.get_config().is_test_env:
      update_result(cls.setup_parallel_in_test_env)
    return result

  @classmethod
  def _get_parallel_tests(
      cls, test_suite: unittest.TestSuite
  ) -> Iterator[unittest.TestCase]:
    """Returns an iterator of test cases to be run in parallel from a suite."""
    and_combine = lambda *funcs: functools.reduce(
        lambda accu, func: lambda item: accu(item) and func(item), funcs
    )
    or_combine = lambda *funcs: functools.reduce(
        lambda accu, func: lambda item: accu(item) or func(item), funcs
    )
    is_decorated = lambda decorator, test: decorator.__name__ == getattr(
        cls._get_test_function(test),
        cls._DECORATOR_NAME,
        None,
    )
    is_parallel = functools.partial(is_decorated, cls.run_in_parallel)
    is_parallel_in_build = functools.partial(
        is_decorated, cls.run_in_parallel_in_build_env
    )
    is_parallel_in_test = functools.partial(
        is_decorated, cls.run_in_parallel_in_test_env
    )
    is_in_build_env = lambda test: test.get_config().is_build_env
    is_in_test_env = lambda test: test.get_config().is_test_env
    combined_filter = or_combine(
        and_combine(is_parallel_in_build, is_in_build_env),
        and_combine(is_parallel_in_test, is_in_test_env),
        is_parallel,
    )
    return filter(combined_filter, test_suite)

  @classmethod
  def _pre_execute_parallel_tests(cls, test_suite: unittest.TestSuite) -> None:
    """Runs setups, then pre-executes the suite's parallel tests concurrently.

    Each pre-executed test function is armed to cache its outcome, which the
    later regular run of the suite consumes.
    """
    for setup_func in cls._get_parallel_setups(test_suite):
      logging.info('Setting up parallel tests with function %s', setup_func)
      setup_func()
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=multiprocessing.cpu_count()
    ) as executor:

      def _execute_test(test):
        # We can't directly call test.run because the function would either not
        # know that it's being pre-executed or not know whether it's being
        # executed by this test runner. We can't call the test function directly
        # because setup and teardown would be missed. We can't set properties
        # of the test function here because the test function has already been
        # wrapped by unittest. The only way we can let the test function know
        # that it needs to cache the next run is to call the function with a
        # parameter first before calling the run method.
        cls._get_test_function(test).__func__(only_set_next_run_caching=True)
        return executor.submit(test.run)

      # groupby only merges adjacent items; suites are loaded class by class.
      for class_name, class_group in itertools.groupby(
          cls._get_parallel_tests(test_suite),
          lambda obj: f'{type(obj).__module__}.{type(obj).__qualname__}',
      ):
        test_group = list(class_group)
        logging.info(
            'Pre-executing %s of %s tests in parallel...',
            len(test_group),
            class_name,
        )

        list(concurrent.futures.as_completed(map(_execute_test, test_group)))
643
644
def _configure_logging(verbose: bool, log_file_dir_path: pathlib.Path):
  """Sets up DEBUG-level file logging plus a console handler.

  Args:
      verbose: When true the console handler shows DEBUG records; otherwise
        it shows INFO and above.
      log_file_dir_path: Directory receiving the log file; created if missing.
  """
  log_format = '%(asctime)s %(filename)s:%(lineno)s:%(levelname)s: %(message)s'
  date_format = '%Y-%m-%d %H:%M:%S'

  log_file = log_file_dir_path / 'asuite_integration_tests.log'
  if log_file.exists():
    # Leave the existing log alone and write to a timestamped file instead.
    stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    log_file = log_file_dir_path / f'asuite_integration_tests_{stamp}.log'
  log_file.parent.mkdir(parents=True, exist_ok=True)

  def _report_log_location() -> None:
    print('Logs are saved to %s' % log_file)

  atexit.register(_report_log_location)

  logging.basicConfig(
      filename=log_file.as_posix(),
      level=logging.DEBUG,
      format=log_format,
      datefmt=date_format,
  )

  console = logging.StreamHandler()
  console.name = 'console'
  console.setLevel(logging.DEBUG if verbose else logging.INFO)
  console.setFormatter(logging.Formatter(log_format))
  logging.getLogger('').addHandler(console)
677
678
@dataclasses.dataclass
class AddArgument:
  """A class to add an argument to the argparse parser and copy to test config.

  The stored `args` and `kwargs` are meant to be forwarded as
  `parser.add_argument(*args, **kwargs)`; `dest` is always injected into
  `kwargs` so the parsed value can be looked up by name afterwards.
  """

  dest: str
  args: tuple[Any, ...]
  kwargs: dict[str, Any]

  def __init__(self, dest: str, *args: Any, **kwargs: Any) -> None:
    """Initializes the AddArgument class.

    Args:
        dest: Attribute name used in the parsed namespace. It is required here
          for adding the parsed value to the test config object.
        *args: Positional arguments for argparse.add_argument.
        **kwargs: Keyword arguments for argparse.add_argument; 'dest' is added.
    """
    self.dest = dest
    self.args = args
    self.kwargs = {**kwargs, 'dest': dest}
700
701
def _parse_known_args(
    argv: list[str],
    additional_args: list[AddArgument],
) -> tuple[argparse.Namespace, list[str]]:
  """Builds the command line parser and parses the arguments it knows.

  Every element of argv is parsed (argv[0] is not skipped), and anything the
  parser does not recognize is returned instead of causing an error.

  Args:
      argv: The argument list to parse.
      additional_args: Extra arguments registered after the built-in ones.

  Returns:
      The (namespace, unrecognized_args) pair from parse_known_args.
  """
  usage_lines = [
      'A script to build and/or run the Asuite integration tests.',
      'Usage examples:',
      '   python <script_path>: Runs both the build and test steps.',
      '   python <script_path> -b -t: Runs both the build and test steps.',
      '   python <script_path> -b: Runs only the build steps.',
      '   python <script_path> -t: Runs only the test steps.',
  ]
  parser = argparse.ArgumentParser(
      add_help=True,
      description='\n'.join(usage_lines) + '\n',
      formatter_class=argparse.RawDescriptionHelpFormatter,
  )

  # Boolean switches, registered in this order so --help lists them the same.
  switches = (
      (
          ('-b', '--build'),
          'Run build steps. Can be set to true together with the test option.'
          ' If both build and test are unset, will run both steps.',
      ),
      (
          ('-t', '--test'),
          'Run test steps. Can be set to true together with the build option.'
          ' If both build and test are unset, will run both steps.',
      ),
      (
          ('--tar_snapshot',),
          'Whether to tar and untar the snapshot storage into/from a single'
          ' file.',
      ),
      (('-v', '--verbose'), 'Whether to set log level to verbose.'),
  )
  for flags, help_text in switches:
    parser.add_argument(
        *flags, action='store_true', default=False, help=help_text
    )

  # The flags below are passed in by the TF Python test runner.
  parser.add_argument(
      '-s',
      '--serial',
      help=(
          'The device serial. Required in test mode when ANDROID_BUILD_TOP is'
          ' not set.'
      ),
  )
  parser.add_argument(
      '--test-output-file',
      help=(
          'The file in which to store the unit test results. This option is'
          ' usually set by TradeFed when running the script with python and'
          ' is optional during manual script execution.'
      ),
  )

  for extra in additional_args:
    parser.add_argument(*extra.args, **extra.kwargs)

  return parser.parse_known_args(argv)
781
782
def _run_test(
    config: IntegrationTestConfiguration,
    argv: list[str],
    test_output_file_path: str = None,
) -> None:
  """Execute integration tests with given test configuration.

  With tar snapshots enabled:
  - In test env, the snapshot tar is extracted and decompressed before the
    tests run, and removal of the workspace and snapshot storage is registered
    to run at exit.
  - In build env, the snapshot storage is compressed and packed into the tar
    after the tests, then removed.

  In test env unittest.main exits the process (SystemExit) once the tests
  finish.

  Args:
      config: The integration test configuration.
      argv: Arguments forwarded to unittest.main.
      test_output_file_path: Optional file that receives the test runner
        output; its parent directories are created if needed.

  Raises:
      EnvironmentError: In test env with tar snapshots, when the snapshot tar
        does not exist.
  """

  compressor = _FileCompressor()

  def cleanup() -> None:
    """Removes the workspace and the snapshot storage directories if present."""
    if config.workspace_path.exists():
      shutil.rmtree(config.workspace_path)
    if config.snapshot_storage_path.exists():
      shutil.rmtree(config.snapshot_storage_path)

  if config.is_test_env and config.is_tar_snapshot:
    if not config.snapshot_storage_tar_path.exists():
      raise EnvironmentError(
          f'Snapshot tar {config.snapshot_storage_tar_path} does not'
          ' exist. Have you run the build mode with --tar_snapshot'
          ' option enabled?'
      )
    logging.info(
        'Extracting tar file %s',
        config.snapshot_storage_tar_path,
    )
    # NOTE(review): no extraction `filter` is passed. Python 3.12+ warns about
    # this and 3.14 defaults to the 'data' filter, which rejects some link
    # targets; confirm which filter suits the snapshot contents.
    with tarfile.open(config.snapshot_storage_tar_path, 'r') as tar:
      tar.extractall(config.snapshot_storage_path.parent.as_posix())
    logging.info('Done extracting tar file')

    logging.info(
        'Decompressing the snapshot storage with %s threads...',
        multiprocessing.cpu_count(),
    )
    start_time = time.time()
    compressor.decompress_all_sub_files(config.snapshot_storage_path)
    logging.info(
        'Decompression finished in %.2f seconds', time.time() - start_time
    )

    atexit.register(cleanup)

  def unittest_main(stream=None):
    """Runs unittest.main with the parallel runner and config injection."""
    # Note that we use a type and not an instance for 'testRunner'
    # since TestProgram forwards its constructor arguments when creating
    # an instance of the runner type. Not doing so would require us to
    # make sure that the parameters passed to TestProgram are aligned
    # with those for creating a runner instance.
    class TestRunner(ParallelTestRunner):
      """Writes test results to the TF-provided file."""

      def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(stream=stream, *args, **kwargs)

    class TestLoader(unittest.TestLoader):
      """Injects the test configuration to the test classes."""

      def loadTestsFromTestCase(self, *args, **kwargs):
        test_suite = super().loadTestsFromTestCase(*args, **kwargs)
        # The config is stored on the class, so one instance is enough.
        for test in test_suite:
          test.__class__.set_config(config)
          break
        return test_suite

    # Setting verbosity is required to generate output that the TradeFed
    # test runner can parse.
    unittest.main(
        testRunner=TestRunner,
        verbosity=3,
        argv=argv,
        testLoader=TestLoader(),
        exit=config.is_test_env,
    )

  if test_output_file_path:
    # parents=True: the output file may sit several missing levels deep.
    pathlib.Path(test_output_file_path).parent.mkdir(
        parents=True, exist_ok=True
    )

    with open(test_output_file_path, 'w', encoding='utf-8') as test_output_file:
      unittest_main(stream=test_output_file)
  else:
    unittest_main(stream=None)

  if config.is_build_env and config.is_tar_snapshot:
    logging.info(
        'Compressing the snapshot storage with %s threads...',
        multiprocessing.cpu_count(),
    )
    start_time = time.time()
    compressor.compress_all_sub_files(config.snapshot_storage_path)
    logging.info(
        'Compression finished in %.2f seconds', time.time() - start_time
    )

    with tarfile.open(config.snapshot_storage_tar_path, 'w') as tar:
      tar.add(
          config.snapshot_storage_path,
          arcname=config.snapshot_storage_path.name,
      )
    cleanup()
886
887
888def main(
889    argv: list[str] = None,
890    make_before_build: list[str] = None,
891    additional_args: list[AddArgument] = None,
892) -> None:
893  """Main method to start the integration tests.
894
895  Args:
896      argv: A list of arguments to parse.
897      make_before_build: A list of targets to make before running build steps.
898      additional_args: A list of additional arguments to be injected to the
899        argparser and test config.
900
901  Raises:
902      EnvironmentError: When some environment variables are missing.
903  """
904  if not argv:
905    argv = sys.argv
906  if make_before_build is None:
907    make_before_build = []
908  if additional_args is None:
909    additional_args = []
910
911  args, unittest_argv = _parse_known_args(argv, additional_args)
912
913  snapshot_storage_dir_name = 'snapshot_storage'
914  snapshot_storage_tar_name = 'snapshot.tar'
915
916  integration_test_out_path = pathlib.Path(
917      tempfile.gettempdir(),
918      'asuite_integration_tests_%s'
919      % pathlib.Path('~').expanduser().name.replace(' ', '_'),
920  )
921
922  if SNAPSHOT_STORAGE_TAR_KEY in os.environ:
923    snapshot_storage_tar_path = pathlib.Path(
924        os.environ[SNAPSHOT_STORAGE_TAR_KEY]
925    )
926    snapshot_storage_tar_path.parent.mkdir(parents=True, exist_ok=True)
927  else:
928    snapshot_storage_tar_path = integration_test_out_path.joinpath(
929        snapshot_storage_tar_name
930    )
931
932  _configure_logging(args.verbose, snapshot_storage_tar_path.parent)
933
934  logging.debug('The os environ is: %s', os.environ)
935
936  # When the build or test is unset, assume it's a local run for both build
937  # and test steps.
938  is_build_test_unset = not args.build and not args.test
939  config = IntegrationTestConfiguration()
940  config.is_build_env = args.build or is_build_test_unset
941  config.is_test_env = args.test or is_build_test_unset
942  config.device_serial = args.serial
943  config.snapshot_storage_path = integration_test_out_path.joinpath(
944      snapshot_storage_dir_name
945  )
946  config.snapshot_storage_tar_path = snapshot_storage_tar_path
947  config.workspace_path = integration_test_out_path.joinpath('workspace')
948  config.is_tar_snapshot = args.tar_snapshot
949  for additional_arg in additional_args:
950    setattr(config, additional_arg.dest, getattr(args, additional_arg.dest))
951
952  if config.is_build_env:
953    if ANDROID_BUILD_TOP_KEY not in os.environ:
954      raise EnvironmentError(
955          f'Environment variable {ANDROID_BUILD_TOP_KEY} is required to'
956          ' build the integration test.'
957      )
958
959    repo_root = os.environ[ANDROID_BUILD_TOP_KEY]
960
961    total, used, free = shutil.disk_usage(repo_root)
962    logging.debug(
963        'Disk usage: Total: {:.2f} GB, Used: {:.2f} GB, Free: {:.2f} GB'.format(
964            total / (1024**3), used / (1024**3), free / (1024**3)
965        )
966    )
967
968    if 'OUT_DIR' in os.environ:
969      out_dir = os.environ['OUT_DIR']
970      if os.path.isabs(out_dir) and not pathlib.Path(out_dir).is_relative_to(
971          repo_root
972      ):
973        raise EnvironmentError(
974            f'$OUT_DIR {out_dir} not relative to the repo root'
975            f' {repo_root} is not supported yet.'
976        )
977    elif 'HOST_OUT' in os.environ:
978      out_dir = (
979          pathlib.Path(os.environ['HOST_OUT']).relative_to(repo_root).parts[0]
980      )
981    else:
982      out_dir = 'out'
983    os.environ['OUT_DIR'] = out_dir
984
985    for target in make_before_build:
986      logging.info(
987          'Building the %s target before integration test run.', target
988      )
989      subprocess.check_call(
990          f'build/soong/soong_ui.bash --make-mode {target}'.split(),
991          cwd=repo_root,
992      )
993
994  if config.is_build_env ^ config.is_test_env:
995    _run_test(config, unittest_argv, args.test_output_file)
996    return
997
998  build_config = copy.deepcopy(config)
999  build_config.is_test_env = False
1000
1001  test_config = copy.deepcopy(config)
1002  test_config.is_build_env = False
1003
1004  _run_test(build_config, unittest_argv, args.test_output_file)
1005  _run_test(test_config, unittest_argv, args.test_output_file)
1006