1#!/usr/bin/env python3
2#
3#   Copyright 2021 - The Android Open Source Project
4#
5#   Licensed under the Apache License, Version 2.0 (the "License");
6#   you may not use this file except in compliance with the License.
7#   You may obtain a copy of the License at
8#
9#       http://www.apache.org/licenses/LICENSE-2.0
10#
11#   Unless required by applicable law or agreed to in writing, software
12#   distributed under the License is distributed on an "AS IS" BASIS,
13#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14#   See the License for the specific language governing permissions and
15#   limitations under the License.
16
17import concurrent.futures
18import logging
19
def task_wrapper(task):
    """Invoke a single (function, args) task.

    This is the callable submitted to the thread pools in this module.

    Args:
        task: an indexable whose element 0 is the callable to invoke and
            whose element 1 is the sequence of positional arguments to pass
            to it. Any further elements are ignored.

    Returns:
        Whatever the callable returns.
    """
    target = task[0]
    return target(*task[1])
33
34
def run_multithread_func_async(log, task):
    """Starts a multi-threaded function asynchronously.

    The task runs on its own single-worker thread pool. The pool is shut
    down without waiting right after submission. The already-submitted task
    still runs to completion, and the worker thread then exits instead of
    idling.

    Args:
        log: log object. Falls back to the ``logging`` module when falsy.
        task: a (function, args) task to be executed in the background.

    Returns:
        Future object representing the execution of the task. An exception
        raised by the task itself is stored on this future.

    Raises:
        Any exception raised while submitting the task, after it is logged.
    """
    if not log:
        log = logging
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        future_object = executor.submit(task_wrapper, task)
    except Exception as e:
        log.error("Exception error %s", e)
        raise
    finally:
        # shutdown(wait=False) does not cancel the pending task. It only
        # lets the worker thread exit once the task has finished.
        executor.shutdown(wait=False)
    return future_object
52
53
def run_multithread_func(log, tasks):
    """Run multi-thread functions and return results.

    Each task is executed through task_wrapper on a thread pool of at most
    MAX_NUMBER_OF_WORKERS threads. Results come back in the same order as
    tasks.

    Args:
        log: log object. Falls back to the ``logging`` module when falsy.
        tasks: a list of (function, args) tasks to be executed in parallel.

    Returns:
        A list with the return value of each task, in task order. Returns an
        empty list when tasks is empty.

    Raises:
        The first exception (in task order) raised by any task, after it is
        logged.
    """
    MAX_NUMBER_OF_WORKERS = 10
    if not log:
        log = logging
    if not tasks:
        # ThreadPoolExecutor rejects max_workers=0, and there is nothing to
        # run anyway.
        return []
    number_of_workers = min(MAX_NUMBER_OF_WORKERS, len(tasks))
    # The context manager shuts the pool down on both the success and the
    # error path, so worker threads are never leaked.
    with concurrent.futures.ThreadPoolExecutor(
            max_workers=number_of_workers) as executor:
        try:
            results = list(executor.map(task_wrapper, tasks))
        except Exception as e:
            log.error("Exception error %s", e)
            raise
    # Not every callable has __name__ (e.g. functools.partial). Fall back to
    # repr so that logging cannot fail after the tasks have succeeded.
    task_names = [getattr(task[0], "__name__", repr(task[0])) for task in tasks]
    log.info("multithread_func %s result: %s", task_names, results)
    return results
79
80
def multithread_func(log, tasks):
    """Run tasks in parallel and report whether every one succeeded.

    Args:
        log: log object, passed through to run_multithread_func.
        tasks: (function, args) tasks to be executed in parallel.

    Returns:
        True if every task returned a truthy value, otherwise False.
    """
    outcomes = run_multithread_func(log, tasks)
    return all(outcomes)
97
98
def multithread_func_and_check_results(log, tasks, expected_results):
    """Multi-thread function wrapper that checks results against expectations.

    Args:
        log: log object. Falls back to the ``logging`` module when falsy.
        tasks: (function, args) tasks to be executed in parallel.
        expected_results: the expected return value of each task, in task
            order.

    Returns:
        True if there is exactly one expected result per task result and
        every result equals its expected value.
        False otherwise. Every mismatch is logged before returning.
    """
    if not log:
        log = logging
    # Materialize once so the length check and the comparison loop below
    # also work when an iterator is passed in.
    expected = list(expected_results)
    results = run_multithread_func(log, tasks)
    log.info("multithread_func result: %s, expecting %s", results,
             expected_results)
    return_value = True
    if len(results) != len(expected):
        # zip() below silently drops surplus entries on either side, so a
        # count mismatch must be reported explicitly.
        log.info("Got %d results, expecting %d", len(results), len(expected))
        return_value = False
    for task, result, expected_result in zip(tasks, results, expected):
        if result != expected_result:
            log.info("Result for task %s is %s, expecting %s", task[0],
                     result, expected_result)
            return_value = False
    return return_value
121