xref: /aosp_15_r20/external/cronet/base/task/thread_pool/test_utils.h (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2016 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef BASE_TASK_THREAD_POOL_TEST_UTILS_H_
6 #define BASE_TASK_THREAD_POOL_TEST_UTILS_H_
7 
8 #include <atomic>
9 
10 #include "base/functional/callback.h"
11 #include "base/memory/raw_ptr.h"
12 #include "base/task/common/checked_lock.h"
13 #include "base/task/post_job.h"
14 #include "base/task/task_features.h"
15 #include "base/task/task_runner.h"
16 #include "base/task/task_traits.h"
17 #include "base/task/thread_pool/delayed_task_manager.h"
18 #include "base/task/thread_pool/pooled_task_runner_delegate.h"
19 #include "base/task/thread_pool/sequence.h"
20 #include "base/task/thread_pool/task_tracker.h"
21 #include "base/task/thread_pool/thread_group.h"
22 #include "base/task/thread_pool/worker_thread_observer.h"
23 #include "build/build_config.h"
24 #include "testing/gmock/include/gmock/gmock.h"
25 #include "third_party/abseil-cpp/absl/types/variant.h"
26 
27 namespace base {
28 namespace internal {
29 
30 struct Task;
31 
32 namespace test {
33 
// Test implementation of WorkerThreadObserver. Main-entry notifications are
// mocked through gmock; main-exit notifications go through a hand-written
// override (see the comment on OnWorkerThreadMainExit() below).
34 class MockWorkerThreadObserver : public WorkerThreadObserver {
35  public:
36   MockWorkerThreadObserver();
  // Not copyable.
37   MockWorkerThreadObserver(const MockWorkerThreadObserver&) = delete;
38   MockWorkerThreadObserver& operator=(const MockWorkerThreadObserver&) = delete;
39   ~MockWorkerThreadObserver() override;
40 
  // NOTE(review): these two are defined outside this header. Presumably
  // |num_calls| is the number of OnWorkerThreadMainExit() calls to permit and
  // WaitCallsOnMainExit() blocks until those calls have happened -- confirm
  // against the implementation file.
41   void AllowCallsOnMainExit(int num_calls);
42   void WaitCallsOnMainExit();
43 
44   // WorkerThreadObserver:
45   MOCK_METHOD0(OnWorkerThreadMainEntry, void());
46   // This doesn't use MOCK_METHOD0 because some tests need to wait for all calls
47   // to happen, which isn't possible with gmock.
48   void OnWorkerThreadMainExit() override;
49 
50  private:
  // Guards the two members below (see their GUARDED_BY annotations).
51   CheckedLock lock_;
52   ConditionVariable on_main_exit_cv_ GUARDED_BY(lock_);
  // Initialized to 0.
53   int allowed_calls_on_main_exit_ GUARDED_BY(lock_) = 0;
54 };
55 
// PooledTaskRunnerDelegate implementation for tests. It is constructed from a
// TaskTracker reference and a DelayedTaskManager pointer, both of which are
// stored in const members; the ThreadGroup pointer starts out null.
56 class MockPooledTaskRunnerDelegate : public PooledTaskRunnerDelegate {
57  public:
58   MockPooledTaskRunnerDelegate(TrackedRef<TaskTracker> task_tracker,
59                                DelayedTaskManager* delayed_task_manager);
60   ~MockPooledTaskRunnerDelegate() override;
61 
62   // PooledTaskRunnerDelegate:
63   bool PostTaskWithSequence(Task task,
64                             scoped_refptr<Sequence> sequence) override;
65   bool EnqueueJobTaskSource(scoped_refptr<JobTaskSource> task_source) override;
66   void RemoveJobTaskSource(scoped_refptr<JobTaskSource> task_source) override;
67   bool ShouldYield(const TaskSource* task_source) override;
68   void UpdatePriority(scoped_refptr<TaskSource> task_source,
69                       TaskPriority priority) override;
70   void UpdateJobPriority(scoped_refptr<TaskSource> task_source,
71                          TaskPriority priority) override;
72 
  // The two methods below are not overrides; they are extra entry points of
  // this mock.
  // NOTE(review): presumably stores |thread_group| in |thread_group_| --
  // confirm against the implementation file.
73   void SetThreadGroup(ThreadGroup* thread_group);
74 
  // NOTE(review): by its name this looks like the no-delay counterpart of
  // PostTaskWithSequence(); unlike that method it returns nothing -- confirm
  // the exact behavior against the implementation file.
75   void PostTaskWithSequenceNow(Task task, scoped_refptr<Sequence> sequence);
76 
77  private:
78   const TrackedRef<TaskTracker> task_tracker_;
79   const raw_ptr<DelayedTaskManager> delayed_task_manager_;
  // Null until assigned.
80   raw_ptr<ThreadGroup> thread_group_ = nullptr;
81 };
82 
83 // A simple MockJobTask that will give |worker_task| a fixed number of times,
84 // possibly in parallel.
// Instances are ref-counted (RefCountedThreadSafe) and not copyable.
85 class MockJobTask : public base::RefCountedThreadSafe<MockJobTask> {
86  public:
87   // Gives |worker_task| to requesting workers |num_tasks_to_run| times.
88   MockJobTask(RepeatingCallback<void(JobDelegate*)> worker_task,
89               size_t num_tasks_to_run);
90 
91   // Gives |worker_task| to a single requesting worker.
92   explicit MockJobTask(base::OnceClosure worker_task);
93 
94   MockJobTask(const MockJobTask&) = delete;
95   MockJobTask& operator=(const MockJobTask&) = delete;
96 
97   // Updates the remaining number of times |worker_task| runs to
98   // |num_tasks_to_run|.
99   void SetNumTasksToRun(size_t num_tasks_to_run);
100 
  // NOTE(review): these two are defined outside this header. Their
  // signatures suggest they serve as the max-concurrency and worker-task
  // callbacks of the job created by GetJobTaskSource() -- confirm against the
  // implementation file.
101   size_t GetMaxConcurrency(size_t worker_count) const;
102   void Run(JobDelegate* delegate);
103 
  // Returns a JobTaskSource built from |from_here|, |traits| and |delegate|.
  // NOTE(review): presumably the returned source runs this MockJobTask --
  // confirm against the implementation file.
104   scoped_refptr<JobTaskSource> GetJobTaskSource(
105       const Location& from_here,
106       const TaskTraits& traits,
107       PooledTaskRunnerDelegate* delegate);
108 
109  private:
  // The destructor is private; the ref-counting base is befriended so that it
  // can still reach it.
110   friend class base::RefCountedThreadSafe<MockJobTask>;
111 
112   ~MockJobTask();
113 
  // Holds one of the two callable types accepted by the constructors above.
114   absl::variant<OnceClosure, RepeatingCallback<void(JobDelegate*)>> task_;
  // Atomic counter; see SetNumTasksToRun().
115   std::atomic_size_t remaining_num_tasks_to_run_;
116 };
117 
118 // Creates a Sequence with given |traits| and pushes |task| to it. If a
119 // TaskRunner is associated with |task|, it should be passed as |task_runner|
120 // along with its |execution_mode|. Returns the created Sequence.
// When omitted, |task_runner| defaults to nullptr and |execution_mode| to
// TaskSourceExecutionMode::kParallel.
121 scoped_refptr<Sequence> CreateSequenceWithTask(
122     Task task,
123     const TaskTraits& traits,
124     scoped_refptr<SequencedTaskRunner> task_runner = nullptr,
125     TaskSourceExecutionMode execution_mode =
126         TaskSourceExecutionMode::kParallel);
127 
128 // Creates a TaskRunner that posts tasks to the thread group owned by
129 // |mock_pooled_task_runner_delegate| with the |execution_mode|.
130 // Caveat: this does not support TaskSourceExecutionMode::kSingleThread.
// |traits| defaults to a default-constructed TaskTraits when omitted.
131 scoped_refptr<TaskRunner> CreatePooledTaskRunnerWithExecutionMode(
132     TaskSourceExecutionMode execution_mode,
133     MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate,
134     const TaskTraits& traits = {});
135 
// Returns a TaskRunner created from |traits| and
// |mock_pooled_task_runner_delegate|.
// NOTE(review): the execution mode of the returned runner is not visible in
// this header; presumably parallel -- confirm against the implementation file.
136 scoped_refptr<TaskRunner> CreatePooledTaskRunner(
137     const TaskTraits& traits,
138     MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate);
139 
// Returns a SequencedTaskRunner created from |traits| and
// |mock_pooled_task_runner_delegate|.
// NOTE(review): presumably the sequenced counterpart of
// CreatePooledTaskRunner() -- confirm against the implementation file.
140 scoped_refptr<SequencedTaskRunner> CreatePooledSequencedTaskRunner(
141     const TaskTraits& traits,
142     MockPooledTaskRunnerDelegate* mock_pooled_task_runner_delegate);
143 
// Takes |task_source| and a |task_tracker| and returns a RegisteredTaskSource.
// NOTE(review): by its name this presumably registers |task_source| with
// |task_tracker| for queueing and marks it as about to run -- confirm the
// exact steps against the implementation file.
144 RegisteredTaskSource QueueAndRunTaskSource(
145     TaskTracker* task_tracker,
146     scoped_refptr<TaskSource> task_source);
147 
148 // Calls StartShutdown() and CompleteShutdown() on |task_tracker|.
// NOTE(review): |task_tracker| is a raw pointer that is called into, so it
// presumably must be non-null -- confirm against the implementation file.
149 void ShutdownTaskTracker(TaskTracker* task_tracker);
150 
151 }  // namespace test
152 }  // namespace internal
153 }  // namespace base
154 
155 #endif  // BASE_TASK_THREAD_POOL_TEST_UTILS_H_
156