1 package kotlinx.coroutines.test
2 
3 import kotlinx.atomicfu.*
4 import kotlinx.coroutines.*
5 import kotlinx.coroutines.channels.*
6 import kotlinx.coroutines.channels.Channel.Factory.CONFLATED
7 import kotlinx.coroutines.internal.*
8 import kotlinx.coroutines.selects.*
9 import kotlin.coroutines.*
10 import kotlin.jvm.*
11 import kotlin.time.*
12 import kotlin.time.Duration.Companion.milliseconds
13 
14 /**
15  * This is a scheduler for coroutines used in tests, providing the delay-skipping behavior.
16  *
17  * [Test dispatchers][TestDispatcher] are parameterized with a scheduler. Several dispatchers can share the
18  * same scheduler, in which case their knowledge about the virtual time will be synchronized. When the dispatchers
19  * require scheduling an event at a later point in time, they notify the scheduler, which will establish the order of
20  * the tasks.
21  *
22  * The scheduler can be queried to advance the time (via [advanceTimeBy]), run all the scheduled tasks advancing the
23  * virtual time as needed (via [advanceUntilIdle]), or run the tasks that are scheduled to run as soon as possible but
24  * haven't yet been dispatched (via [runCurrent]).
25  */
public class TestCoroutineScheduler : AbstractCoroutineContextElement(TestCoroutineScheduler),
    CoroutineContext.Element {

    /** @suppress */
    public companion object Key : CoroutineContext.Key<TestCoroutineScheduler>

    /** This heap stores the knowledge about which dispatchers are interested in which moments of virtual time. */
    // TODO: all the synchronization is done via a separate lock, so a non-thread-safe priority queue can be used.
    private val events = ThreadSafeHeap<TestDispatchEvent<Any>>()

    /**
     * Guards every access to [events] and to the backing field of [currentTime].
     *
     * Because an event is taken out of [events] and [currentTime] is moved to its time while holding this lock,
     * [currentTime] can't exceed the time of the earliest event in [events].
     */
    private val lock = SynchronizedObject()

    /** This counter establishes some order on the events that happen at the same virtual time. */
    private val count = atomic(0L)

    /** The current virtual time in milliseconds. */
    @ExperimentalCoroutinesApi
    public var currentTime: Long = 0
        get() = synchronized(lock) { field }
        private set

    /** A channel for notifying about the fact that a foreground work dispatch recently happened. */
    private val dispatchEventsForeground: Channel<Unit> = Channel(CONFLATED)

    /** A channel for notifying about the fact that a dispatch recently happened. */
    private val dispatchEvents: Channel<Unit> = Channel(CONFLATED)

    /**
     * Registers a request for the scheduler to notify [dispatcher] at a virtual moment [timeDeltaMillis] milliseconds
     * later via [TestDispatcher.processEvent], which will be called with the provided [marker] object.
     *
     * The target moment is clamped to [Long.MAX_VALUE] instead of overflowing. Events targeting the same moment are
     * ordered by the value taken from [count] at registration.
     *
     * [context] determines whether the event is foreground work (it is unless [context] contains [BackgroundWork]).
     * [isCancelled] is applied to [marker] when [isIdle] is queried in non-strict mode.
     *
     * Returns the handler which can be used to cancel the registration.
     *
     * @throws IllegalArgumentException if [timeDeltaMillis] is negative.
     * @throws IllegalStateException if [context] contains a different [TestCoroutineScheduler].
     */
    internal fun <T : Any> registerEvent(
        dispatcher: TestDispatcher,
        timeDeltaMillis: Long,
        marker: T,
        context: CoroutineContext,
        isCancelled: (T) -> Boolean
    ): DisposableHandle {
        require(timeDeltaMillis >= 0) { "Attempted scheduling an event earlier in time (with the time delta $timeDeltaMillis)" }
        checkSchedulerInContext(this, context)
        val ordinal = count.getAndIncrement()
        val isForeground = isForegroundWork(context)
        return synchronized(lock) {
            val time = addClamping(currentTime, timeDeltaMillis)
            val event = TestDispatchEvent(dispatcher, ordinal, time, marker as Any, isForeground) { isCancelled(marker) }
            events.addLast(event)
            /** can't be moved above: otherwise, [onDispatchEventForeground] or [onDispatchEvent] could consume the
             * token sent here before there's actually anything in the event queue. */
            sendDispatchEvent(context)
            DisposableHandle {
                synchronized(lock) {
                    events.remove(event)
                }
            }
        }
    }

    /**
     * Runs the next enqueued task, advancing the virtual time to the time of its scheduled awakening,
     * unless [condition] holds.
     *
     * [condition] is evaluated while holding [lock]; the task itself is processed after the lock is released.
     *
     * @return `false` if [condition] held or there was no enqueued task, `true` if a task was run.
     */
    internal fun tryRunNextTaskUnless(condition: () -> Boolean): Boolean {
        val event = synchronized(lock) {
            if (condition()) return false
            val event = events.removeFirstOrNull() ?: return false
            if (currentTime > event.time)
                currentTimeAheadOfEvents()
            currentTime = event.time
            event
        }
        event.dispatcher.processEvent(event.marker)
        return true
    }

    /**
     * Runs the enqueued tasks in the specified order, advancing the virtual time as needed until there are no more
     * tasks associated with the dispatchers linked to this scheduler.
     *
     * Stops as soon as only background tasks remain in the queue.
     *
     * A breaking change from `TestCoroutineDispatcher.advanceUntilIdle` is that it no longer returns the total number
     * of milliseconds by which the execution of this method has advanced the virtual time. If you want to recreate that
     * functionality, query [currentTime] before and after the execution to achieve the same result.
     */
    public fun advanceUntilIdle(): Unit = advanceUntilIdleOr { events.none(TestDispatchEvent<*>::isForeground) }

    /**
     * Runs tasks one at a time, advancing the virtual time as needed, until either [condition] holds or no tasks
     * remain.
     *
     * [condition]: guaranteed to be invoked under the lock.
     */
    internal fun advanceUntilIdleOr(condition: () -> Boolean) {
        while (true) {
            if (!tryRunNextTaskUnless(condition))
                return
        }
    }

    /**
     * Runs the tasks that are scheduled to execute at this moment of virtual time.
     *
     * The moment is captured once at the start, so tasks that get scheduled for that same moment while running are
     * executed too, but the virtual time is never advanced.
     */
    public fun runCurrent() {
        val timeMark = synchronized(lock) { currentTime }
        while (true) {
            val event = synchronized(lock) {
                events.removeFirstIf { it.time <= timeMark } ?: return
            }
            event.dispatcher.processEvent(event.marker)
        }
    }

    /**
     * Moves the virtual clock of this dispatcher forward by [the specified amount][delayTimeMillis], running the
     * scheduled tasks in the meantime.
     *
     * Breaking changes from [TestCoroutineDispatcher.advanceTimeBy]:
     * - Intentionally doesn't return a `Long` value, as its use cases are unclear. We may restore it in the future;
     *   please describe your use cases at [the issue tracker](https://github.com/Kotlin/kotlinx.coroutines/issues/).
     *   For now, it's possible to query [currentTime] before and after execution of this method, to the same effect.
     * - It doesn't run the tasks that are scheduled at exactly [currentTime] + [delayTimeMillis]. For example,
     *   advancing the time by one millisecond used to run the tasks at the current millisecond *and* the next
     *   millisecond, but now will stop just before executing any task starting at the next millisecond.
     * - Overflowing the target time used to lead to nothing being done, but will now run the tasks scheduled at up to
     *   (but not including) [Long.MAX_VALUE].
     *
     * @throws IllegalArgumentException if passed a negative [delay][delayTimeMillis].
     */
    @ExperimentalCoroutinesApi
    public fun advanceTimeBy(delayTimeMillis: Long): Unit = advanceTimeBy(delayTimeMillis.milliseconds)

    /**
     * Moves the virtual clock of this dispatcher forward by [the specified amount][delayTime], running the
     * scheduled tasks in the meantime.
     *
     * Only tasks scheduled strictly before [currentTime] + [delayTime] are run; afterwards the clock is set to exactly
     * that target moment (clamped to [Long.MAX_VALUE] on overflow).
     *
     * @throws IllegalArgumentException if passed a negative [delay][delayTime].
     */
    public fun advanceTimeBy(delayTime: Duration) {
        require(!delayTime.isNegative()) { "Can not advance time by a negative delay: $delayTime" }
        val startingTime = currentTime
        val targetTime = addClamping(startingTime, delayTime.inWholeMilliseconds)
        while (true) {
            val event = synchronized(lock) {
                val timeMark = currentTime
                val event = events.removeFirstIf { targetTime > it.time }
                when {
                    event == null -> {
                        currentTime = targetTime
                        return
                    }
                    timeMark > event.time -> currentTimeAheadOfEvents()
                    else -> {
                        currentTime = event.time
                        event
                    }
                }
            }
            event.dispatcher.processEvent(event.marker)
        }
    }

    /**
     * Checks that the only tasks remaining in the scheduler are cancelled.
     *
     * With [strict] set, returns `true` only if no tasks remain at all. Otherwise, returns `true` if every remaining
     * task reports itself as cancelled.
     */
    internal fun isIdle(strict: Boolean = true): Boolean =
        synchronized(lock) {
            if (strict) events.isEmpty else events.none { !it.isCancelled() }
        }

    /**
     * Notifies this scheduler about a dispatch event.
     *
     * [context] is the context in which the task will be dispatched. Every dispatch wakes up [onDispatchEvent];
     * only dispatches of foreground work also wake up [onDispatchEventForeground].
     */
    internal fun sendDispatchEvent(context: CoroutineContext) {
        dispatchEvents.trySend(Unit)
        if (isForegroundWork(context))
            dispatchEventsForeground.trySend(Unit)
    }

    /**
     * Returns `true` unless [context] is marked with [BackgroundWork].
     *
     * Shared by [registerEvent] and [sendDispatchEvent] so both always classify a context the same way.
     */
    private fun isForegroundWork(context: CoroutineContext): Boolean = context[BackgroundWork] === null

    /**
     * Waits for a notification about a dispatch event.
     */
    internal suspend fun receiveDispatchEvent() = dispatchEvents.receive()

    /**
     * Consumes the knowledge that a dispatch event happened recently.
     */
    internal val onDispatchEvent: SelectClause1<Unit> get() = dispatchEvents.onReceive

    /**
     * Consumes the knowledge that a foreground work dispatch event happened recently.
     */
    internal val onDispatchEventForeground: SelectClause1<Unit> get() = dispatchEventsForeground.onReceive

    /**
     * Returns the [TimeSource] representation of the virtual time of this scheduler.
     */
    public val timeSource: TimeSource.WithComparableMarks = object : AbstractLongTimeSource(DurationUnit.MILLISECONDS) {
        override fun read(): Long = currentTime
    }
}
226 
227 // Some error-throwing functions for pretty stack traces
currentTimeAheadOfEventsnull228 private fun currentTimeAheadOfEvents(): Nothing = invalidSchedulerState()
229 
/** Throws an [IllegalStateException] telling the user that the scheduler's internal invariants were violated. */
private fun invalidSchedulerState(): Nothing {
    val message = "The test scheduler entered an invalid state. Please report this at https://github.com/Kotlin/kotlinx.coroutines/issues."
    throw IllegalStateException(message)
}
232 
/**
 * A [ThreadSafeHeap] node describing one scheduled task.
 *
 * Nodes are ordered primarily by [time], the planned execution moment, and ties are broken by `count`, a
 * tie-breaking counter supplied on construction.
 */
private class TestDispatchEvent<T>(
    @JvmField val dispatcher: TestDispatcher,
    private val count: Long,
    @JvmField val time: Long,
    @JvmField val marker: T,
    @JvmField val isForeground: Boolean,
    // TODO: remove once the deprecated API is gone
    @JvmField val isCancelled: () -> Boolean
) : Comparable<TestDispatchEvent<*>>, ThreadSafeHeapNode {
    override var heap: ThreadSafeHeap<*>? = null
    override var index: Int = 0

    override fun compareTo(other: TestDispatchEvent<*>): Int {
        val byTime = time.compareTo(other.time)
        return if (byTime != 0) byTime else count.compareTo(other.count)
    }

    override fun toString(): String {
        val backgroundSuffix = if (isForeground) "" else ", background"
        return "TestDispatchEvent(time=$time, dispatcher=$dispatcher$backgroundSuffix)"
    }
}
251 
/**
 * Adds [a] and [b], saturating at [Long.MAX_VALUE] instead of overflowing.
 *
 * Only meaningful for non-negative [a] and [b]: their true sum then stays below `2 * Long.MAX_VALUE`, so an overflow
 * always shows up as a negative wrapped-around value.
 */
private fun addClamping(a: Long, b: Long): Long {
    val sum = a + b
    return if (sum < 0) Long.MAX_VALUE else sum
}
254 
checkSchedulerInContextnull255 internal fun checkSchedulerInContext(scheduler: TestCoroutineScheduler, context: CoroutineContext) {
256     context[TestCoroutineScheduler]?.let {
257         check(it === scheduler) {
258             "Detected use of different schedulers. If you need to use several test coroutine dispatchers, " +
259                 "create one `TestCoroutineScheduler` and pass it to each of them."
260         }
261     }
262 }
263 
264 /**
265  * A coroutine context key denoting that the work is to be executed in the background.
266  * @see [TestScope.backgroundScope]
267  */
268 internal object BackgroundWork : CoroutineContext.Key<BackgroundWork>, CoroutineContext.Element {
269     override val key: CoroutineContext.Key<*>
270         get() = this
271 
toStringnull272     override fun toString(): String = "BackgroundWork"
273 }
274 
/** Returns `true` when [ThreadSafeHeap.find] reports no element matching [predicate]. */
private fun <T> ThreadSafeHeap<T>.none(predicate: (T) -> Boolean): Boolean where T : ThreadSafeHeapNode, T : Comparable<T> {
    val firstMatch = find(predicate)
    return firstMatch == null
}
277