// Copyright 2022 The ChromiumOS Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. //! Multi-thread worker. #![deny(missing_docs)] use std::collections::VecDeque; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::sync::Arc; use std::thread; use std::time::Duration; use anyhow::Context; use base::error; use base::Event; use base::EventWaitResult; use sync::Condvar; use sync::Mutex; /// Task to run on the worker threads. pub trait Task { /// Executes the task. fn execute(self); } /// Multi thread based worker executing a single type [Task]. /// /// See the doc of [Channel] as well for the behaviors of it. pub struct Worker { /// Shared [Channel] with the worker threads. pub channel: Arc>, handles: Vec>, } impl Worker { /// Spawns the numbers of worker threads. pub fn new(len_channel: usize, n_workers: usize) -> Self { let channel = Arc::new(Channel::::new(len_channel, n_workers)); let mut handles = Vec::with_capacity(n_workers); for _ in 0..n_workers { let context = channel.clone(); let handle = thread::spawn(move || { Self::worker_thread(context); }); handles.push(handle); } Self { channel, handles } } fn worker_thread(context: Arc>) { while let Some(task) = context.pop() { task.execute(); } } /// Closes the channel and wait for worker threads shutdown. /// /// This also waits for all the tasks in the channel to be executed. pub fn close(self) { self.channel.close(); for handle in self.handles { match handle.join() { Ok(()) => {} Err(e) => { error!("failed to wait for worker thread: {:?}", e); } } } } } /// MPMC (Multi Producers Multi Consumers) queue integrated with [Worker]. /// /// [Channel] offers [Channel::wait_complete()] to guarantee all the tasks are executed. /// /// This only exposes methods for producers. pub struct Channel { state: Mutex>, consumer_wait: Condvar, producer_wait: Condvar, n_consumers: usize, } impl Channel { fn new(len: usize, n_consumers: usize) -> Self { Self { state: Mutex::new(ChannelState::new(len)), consumer_wait: Condvar::new(), producer_wait: Condvar::new(), n_consumers, } } fn close(&self) { let mut state = self.state.lock(); state.is_closed = true; self.consumer_wait.notify_all(); self.producer_wait.notify_all(); } /// Pops a task from the channel. /// /// If the queue is closed and also **empty**, this returns [None]. This returns all the tasks /// in the queue even while this is closed. #[inline] fn pop(&self) -> Option { let mut state = self.state.lock(); loop { let was_full = state.queue.len() == state.capacity; if let Some(item) = state.queue.pop_front() { if was_full { // notification for a producer waiting for `push()`. self.producer_wait.notify_one(); } return Some(item); } else { if state.is_closed { return None; } state.n_waiting += 1; if state.n_waiting == self.n_consumers { // notification for producers waiting for `wait_complete()`. self.producer_wait.notify_all(); } state = self.consumer_wait.wait(state); state.n_waiting -= 1; } } } /// Push a task. /// /// This blocks if the channel is full. /// /// If the channel is closed, this returns `false`. pub fn push(&self, item: T) -> bool { let mut state = self.state.lock(); // Wait until the queue has room to push a task. while state.queue.len() == state.capacity { if state.is_closed { return false; } state = self.producer_wait.wait(state); } if state.is_closed { return false; } state.queue.push_back(item); self.consumer_wait.notify_one(); true } /// Wait until all the tasks have been executed. /// /// This guarantees that all the tasks in this channel are not only consumed but also executed. pub fn wait_complete(&self) { let mut state = self.state.lock(); while !(state.queue.is_empty() && state.n_waiting == self.n_consumers) { state = self.producer_wait.wait(state); } } } struct ChannelState { queue: VecDeque, capacity: usize, n_waiting: usize, is_closed: bool, } impl ChannelState { fn new(capacity: usize) -> Self { Self { queue: VecDeque::with_capacity(capacity), capacity, n_waiting: 0, is_closed: false, } } } /// The event channel for background jobs. /// /// This sends an abort request from the main thread to the job thread via atomic boolean flag. /// /// This notifies the main thread that the job thread is completed via [Event]. pub struct BackgroundJobControl { event: Event, abort_flag: AtomicBool, } impl BackgroundJobControl { /// Creates [BackgroundJobControl]. pub fn new() -> anyhow::Result { Ok(Self { event: Event::new()?, abort_flag: AtomicBool::new(false), }) } /// Creates [BackgroundJob]. pub fn new_job(&self) -> BackgroundJob<'_> { BackgroundJob { event: &self.event, abort_flag: &self.abort_flag, } } /// Abort the background job. pub fn abort(&self) { self.abort_flag.store(true, Ordering::Release); } /// Reset the internal state for a next job. /// /// Returns false, if the event is already reset and no event exists. pub fn reset(&self) -> anyhow::Result { self.abort_flag.store(false, Ordering::Release); Ok(matches!( self.event .wait_timeout(Duration::ZERO) .context("failed to get job complete event")?, EventWaitResult::Signaled )) } /// Returns the event to notify the completion of background job. pub fn get_completion_event(&self) -> &Event { &self.event } } /// Background job context. /// /// When dropped, this sends an event to the main thread via [Event]. pub struct BackgroundJob<'a> { event: &'a Event, abort_flag: &'a AtomicBool, } impl BackgroundJob<'_> { /// Returns whether the background job is aborted or not. pub fn is_aborted(&self) -> bool { self.abort_flag.load(Ordering::Acquire) } } impl Drop for BackgroundJob<'_> { fn drop(&mut self) { self.event.signal().expect("send job complete event"); } } #[cfg(test)] mod tests { use std::time::Duration; use super::*; #[derive(Clone, Copy)] struct Context { n_consume: usize, n_executed: usize, } struct FakeTask { context: Mutex, waker: Condvar, } impl FakeTask { fn new() -> Arc { Arc::new(Self { context: Mutex::new(Context { n_consume: 0, n_executed: 0, }), waker: Condvar::new(), }) } fn consume(&self, count: usize) { let mut context = self.context.lock(); context.n_consume += count; self.waker.notify_all(); } fn n_executed(&self) -> usize { self.context.lock().n_executed } } impl Task for Arc { fn execute(self) { let mut context = self.context.lock(); while context.n_consume == 0 { context = self.waker.wait(context); } context.n_consume -= 1; context.n_executed += 1; } } fn wait_thread_with_timeout(join_handle: thread::JoinHandle, timeout_millis: u64) -> T { for _ in 0..timeout_millis { if join_handle.is_finished() { return join_handle.join().unwrap(); } thread::sleep(Duration::from_millis(1)); } panic!("thread join timeout"); } fn poll_until_with_timeout(f: F, timeout_millis: u64) where F: Fn() -> bool, { for _ in 0..timeout_millis { if f() { break; } thread::sleep(Duration::from_millis(1)); } } #[test] fn test_worker() { let worker = Worker::new(2, 4); let task = FakeTask::new(); let channel = worker.channel.clone(); for _ in 0..4 { assert!(channel.push(task.clone())); } assert_eq!(task.n_executed(), 0); task.consume(4); worker.channel.wait_complete(); assert_eq!(task.n_executed(), 4); worker.close(); } #[test] fn test_worker_push_after_close() { let worker = Worker::new(2, 4); let task = FakeTask::new(); let channel = worker.channel.clone(); worker.close(); assert!(!channel.push(task)); } #[test] fn test_worker_push_block() { let worker = Worker::new(2, 4); let task = FakeTask::new(); let channel = worker.channel.clone(); let task_cloned = task.clone(); // push tasks on another thread to avoid blocking forever wait_thread_with_timeout( thread::spawn(move || { for _ in 0..6 { assert!(channel.push(task_cloned.clone())); } }), 100, ); let channel = worker.channel.clone(); let task_cloned = task.clone(); let push_thread = thread::spawn(move || { assert!(channel.push(task_cloned)); }); thread::sleep(Duration::from_millis(10)); assert!(!push_thread.is_finished()); task.consume(1); wait_thread_with_timeout(push_thread, 100); task.consume(6); #[allow(clippy::redundant_clone)] let task_clone = task.clone(); poll_until_with_timeout(|| task_clone.n_executed() == 7, 100); assert_eq!(task.n_executed(), 7); worker.close(); } #[test] fn test_worker_close_on_push_blocked() { let worker = Worker::new(2, 4); let task = FakeTask::new(); let channel = worker.channel.clone(); let task_cloned = task.clone(); // push tasks on another thread to avoid blocking forever wait_thread_with_timeout( thread::spawn(move || { for _ in 0..6 { assert!(channel.push(task_cloned.clone())); } }), 100, ); let channel = worker.channel.clone(); let task_cloned = task.clone(); let push_thread = thread::spawn(move || channel.push(task_cloned)); // sleep to run push_thread. thread::sleep(Duration::from_millis(10)); // close blocks until all the task are executed. let close_thread = thread::spawn(move || { worker.close(); }); let push_result = wait_thread_with_timeout(push_thread, 100); // push fails. assert!(!push_result); // cleanup task.consume(6); wait_thread_with_timeout(close_thread, 100); } #[test] fn new_background_job_event() { assert!(BackgroundJobControl::new().is_ok()); } #[test] fn background_job_is_not_aborted_default() { let event = BackgroundJobControl::new().unwrap(); let job = event.new_job(); assert!(!job.is_aborted()); } #[test] fn abort_background_job() { let event = BackgroundJobControl::new().unwrap(); let job = event.new_job(); event.abort(); assert!(job.is_aborted()); } #[test] fn reset_background_job() { let event = BackgroundJobControl::new().unwrap(); event.abort(); event.reset().unwrap(); let job = event.new_job(); assert!(!job.is_aborted()); } #[test] fn reset_background_job_event() { let event = BackgroundJobControl::new().unwrap(); let job = event.new_job(); drop(job); assert!(event.reset().unwrap()); } #[test] fn reset_background_job_event_twice() { let event = BackgroundJobControl::new().unwrap(); let job = event.new_job(); drop(job); event.reset().unwrap(); assert!(!event.reset().unwrap()); } #[test] fn reset_background_job_event_no_jobs() { let event = BackgroundJobControl::new().unwrap(); assert!(!event.reset().unwrap()); } }