package kotlinx.coroutines.scheduling

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import java.util.concurrent.atomic.*
import kotlin.jvm.internal.Ref.ObjectRef

internal const val BUFFER_CAPACITY_BASE = 7
internal const val BUFFER_CAPACITY = 1 shl BUFFER_CAPACITY_BASE // 128 by default
internal const val MASK = BUFFER_CAPACITY - 1 // 127 by default: maps a monotonically growing index onto a buffer slot

/** Return value of [WorkQueue.trySteal]: a task was stolen and stored into the provided reference. */
internal const val TASK_STOLEN = -1L

/** Return value of [WorkQueue.trySteal]: there was nothing eligible to steal. */
internal const val NOTHING_TO_STEAL = -2L

/**
 * Bit set selecting which kinds of tasks a steal attempt may take:
 * [STEAL_BLOCKING_ONLY] (bit `1`) admits blocking tasks, [STEAL_CPU_ONLY] (bit `2`) admits CPU tasks,
 * and [STEAL_ANY] is both bits combined.
 */
internal typealias StealingMode = Int
internal const val STEAL_ANY: StealingMode = 3
internal const val STEAL_CPU_ONLY: StealingMode = 2
internal const val STEAL_BLOCKING_ONLY: StealingMode = 1

/**
 * The single [StealingMode] bit that matches this task's kind.
 * [WorkQueue] checks `(maskForStealingMode and stealingMode) != 0` before stealing its last-scheduled task.
 */
internal inline val Task.maskForStealingMode: Int
    get() = if (isBlocking) STEAL_BLOCKING_ONLY else STEAL_CPU_ONLY

/**
 * Tightly coupled with the [CoroutineScheduler] queue of pending tasks, but extracted to a separate file for simplicity.
 * At any moment the queue is used only by [CoroutineScheduler.Worker] threads. It has only one producer (the worker
 * owning this queue) and any number of consumers: other pool workers that are trying to steal work.
 *
 * ### Fairness
 *
 * [WorkQueue] provides semi-FIFO order, but gives priority to the most recently submitted task. The assumption is
 * that these two tasks (the current one and the submitted one) are communicating and sharing state, which makes
 * such communication extremely fast.
 * E.g. submitted jobs [1, 2, 3, 4] will be executed in [4, 1, 2, 3] order.
 *
 * ### Algorithm and implementation details
 * This is a regular SPMC bounded queue with the additional property that tasks can be removed from the middle of the queue
 * (scheduler workers without a CPU permit steal blocking tasks via this mechanism). This property forces us to use CAS
 * in order to properly claim a value from the buffer.
 * Moreover, [Task] objects are reusable, so it may seem that this queue is prone to the ABA problem.
 * Indeed, it formally has the ABA problem, but the whole processing logic is written in such a way that this ABA is harmless.
 * I have discovered a truly marvelous proof of this, which this KDoc is too narrow to contain.
 */
internal class WorkQueue {

    /*
     * We read two independent counters here.
     * The producer index is incremented only by the owner.
     * The consumer index is incremented both by the owner and by external (stealing) threads.
     *
     * The only harmful race is:
     * [T1] readProducerIndex (1) preemption(2) readConsumerIndex(5)
     * [T2] changeProducerIndex (3)
     * [T3] changeConsumerIndex (4)
     *
     * This can make the computed size negative, or bigger than the actual size at any single moment of time.
     * It is harmless in general, because stealing is throttled by the timer.
     * Negative sizes can be observed only when a non-owner reads the size, which happens only
     * for diagnostic toString().
     */
    private val bufferSize: Int get() = producerIndex.value - consumerIndex.value

    /** Approximate number of tasks in this queue: [bufferSize] plus one if the last-scheduled slot is occupied. */
    internal val size: Int get() = if (lastScheduledTask.value != null) bufferSize + 1 else bufferSize

    /** Ring buffer indexed by `index and MASK`; a `null` slot is empty or was already claimed by a consumer. */
    private val buffer: AtomicReferenceArray<Task?> = AtomicReferenceArray(BUFFER_CAPACITY)

    /** The most recently submitted task, kept outside [buffer] so the owner runs it next (see "Fairness"). */
    private val lastScheduledTask = atomic<Task?>(null)

    private val producerIndex = atomic(0)
    private val consumerIndex = atomic(0)

    // Number of blocking tasks currently stored in `buffer` (a task in the lastScheduledTask slot is NOT counted).
    // It is a shortcut that lets exclusive-mode scans bail out instead of scanning a queue with no blocking tasks.
    private val blockingTasksInBuffer = atomic(0)

    /**
     * Retrieves and removes the next task to run: the last-scheduled task if present,
     * otherwise the task at the head of the buffer. Returns `null` if the queue is empty.
     * Invariant: this method is called only by the owner of the queue.
     */
    fun poll(): Task? = lastScheduledTask.getAndSet(null) ?: pollBuffer()

    /**
     * Adds [task] to this queue.
     *
     * With `fair = false`, [task] replaces the last-scheduled task (so it runs next), and the task previously held in
     * that slot, if any, is appended to the tail of the buffer. With `fair = true`, [task] is appended to the tail of
     * the buffer directly.
     *
     * Invariant: called only by the owner of the queue.
     *
     * @return `null` if everything was stored; otherwise the task that did not fit because the buffer is full.
     *   Note that with `fair = false` this is the *previously* last-scheduled task, not [task] itself.
     */
    fun add(task: Task, fair: Boolean = false): Task? {
        if (fair) return addLast(task)
        val previous = lastScheduledTask.getAndSet(task) ?: return null
        return addLast(previous)
    }

    /**
     * Appends [task] to the tail of the buffer.
     * Invariant: called only by the owner of the queue.
     *
     * @return `null` if [task] was added, or [task] itself if the buffer already holds `BUFFER_CAPACITY - 1` tasks.
     */
    private fun addLast(task: Task): Task? {
        if (bufferSize == BUFFER_CAPACITY - 1) return task
        // The counter is bumped before the task is published (lazySet + producerIndex below). Do not reorder
        // without re-checking the `blockingTasksInBuffer.value == 0` bail-outs in the scanning functions.
        if (task.isBlocking) blockingTasksInBuffer.incrementAndGet()
        val nextIndex = producerIndex.value and MASK
        /*
         * If the current element is not null, we're racing with a really slow consumer that committed the consumer
         * index but hasn't yet nulled out the slot, effectively preventing us from using it.
         * Such situations are very rare in practice (although possible), and we decided to give up a progress guarantee
         * to have a stronger invariant: "add to a queue with bufferSize == 0 is always successful".
         * This algorithm could still be wait-free for add, but if and only if tasks were not reusable; otherwise
         * nulling out the buffer wouldn't be possible.
         */
        while (buffer[nextIndex] != null) {
            Thread.yield()
        }
        buffer.lazySet(nextIndex, task)
        producerIndex.incrementAndGet()
        return null
    }

    /**
     * Tries stealing one task from this queue into the [stolenTaskRef] argument.
     * The buffer is tried first; if that yields nothing, the last-scheduled task is tried.
     *
     * Returns [TASK_STOLEN] if a task was stolen (it is stored into [stolenTaskRef]), [NOTHING_TO_STEAL] if there is
     * nothing eligible to steal, or a positive number of nanoseconds that should pass before the last-scheduled
     * task becomes available to steal.
     *
     * [StealingMode] controls what tasks to steal:
     * - [STEAL_ANY] is the default mode for the scheduler: the task at the head of the buffer (FIFO order) is stolen
     * - [STEAL_BLOCKING_ONLY] steals *an arbitrary* blocking task; it is used by the scheduler when helping in Dispatchers.IO mode
     * - [STEAL_CPU_ONLY] is a kludge for `runSingleTaskFromCurrentSystemDispatcher`
     */
    fun trySteal(stealingMode: StealingMode, stolenTaskRef: ObjectRef<Task?>): Long {
        val task = when (stealingMode) {
            STEAL_ANY -> pollBuffer()
            else -> stealWithExclusiveMode(stealingMode)
        }

        if (task != null) {
            stolenTaskRef.element = task
            return TASK_STOLEN
        }
        return tryStealLastScheduled(stealingMode, stolenTaskRef)
    }

    /**
     * Steals only tasks of a particular kind, potentially scanning the whole buffer from the head (consumer side)
     * towards the tail. Blocking tasks are taken for [STEAL_BLOCKING_ONLY], CPU tasks for any other mode.
     * Returns `null` if no matching task could be claimed.
     */
    private fun stealWithExclusiveMode(stealingMode: StealingMode): Task? {
        var start = consumerIndex.value
        val end = producerIndex.value
        val onlyBlocking = stealingMode == STEAL_BLOCKING_ONLY
        while (start != end) {
            // Bail out if there is no blocking work for us
            if (onlyBlocking && blockingTasksInBuffer.value == 0) return null
            // Return the claimed task, or move on to the next slot if this one did not match or was taken concurrently
            return tryExtractFromTheMiddle(start++, onlyBlocking) ?: continue
        }

        return null
    }

    // Polls for a blocking task, invoked only by the owner
    // NB: ONLY for the runSingleTask method
    fun pollBlocking(): Task? = pollWithExclusiveMode(onlyBlocking = true /* only blocking */)

    // Polls for a CPU task, invoked only by the owner
    // NB: ONLY for the runSingleTask method
    fun pollCpu(): Task? = pollWithExclusiveMode(onlyBlocking = false /* only cpu */)

    /**
     * Owner-side retrieval of a task of one kind only (blocking if [onlyBlocking], CPU otherwise).
     * It first tries the last-scheduled slot. Then it scans the buffer backwards, from the tail (producer side)
     * towards the head, returning the first matching task it manages to claim, or `null`.
     */
    private fun pollWithExclusiveMode(/* Only blocking OR only CPU */ onlyBlocking: Boolean): Task? {
        while (true) { // Poll the slot
            val lastScheduled = lastScheduledTask.value ?: break
            if (lastScheduled.isBlocking != onlyBlocking) break
            if (lastScheduledTask.compareAndSet(lastScheduled, null)) {
                return lastScheduled
            } // Failed -> someone else stole it
        }

        // Failed to poll the slot, scan the queue
        val start = consumerIndex.value
        var end = producerIndex.value
        while (start != end) {
            // Bail out if there is no blocking work for us
            if (onlyBlocking && blockingTasksInBuffer.value == 0) return null
            val task = tryExtractFromTheMiddle(--end, onlyBlocking)
            if (task != null) {
                return task
            }
        }
        return null
    }

    /**
     * Claims the task at logical position [index] if it exists and its kind matches [onlyBlocking].
     * The slot is claimed by CAS-ing it to `null`, which [pollBuffer] later skips. Returns the claimed task,
     * or `null` if the slot was empty, did not match, or was claimed concurrently.
     */
    private fun tryExtractFromTheMiddle(index: Int, onlyBlocking: Boolean): Task? {
        val arrayIndex = index and MASK
        val value = buffer[arrayIndex]
        if (value != null && value.isBlocking == onlyBlocking && buffer.compareAndSet(arrayIndex, value, null)) {
            // Only blocking tasks are tracked by the counter
            if (onlyBlocking) blockingTasksInBuffer.decrementAndGet()
            return value
        }
        return null
    }

    /**
     * Moves every task from this queue to [globalQueue]: the last-scheduled task first (if any),
     * then the buffer contents in FIFO order.
     */
    fun offloadAllWorkTo(globalQueue: GlobalQueue) {
        lastScheduledTask.getAndSet(null)?.let { globalQueue.addLast(it) }
        while (pollTo(globalQueue)) {
            // Steal everything
        }
    }

    /**
     * Tries to steal the last-scheduled task, honouring [stealingMode] and the work-stealing time resolution.
     * The contract on the return value is the same as for [trySteal].
     */
    private fun tryStealLastScheduled(stealingMode: StealingMode, stolenTaskRef: ObjectRef<Task?>): Long {
        while (true) {
            val lastScheduled = lastScheduledTask.value ?: return NOTHING_TO_STEAL
            if ((lastScheduled.maskForStealingMode and stealingMode) == 0) {
                return NOTHING_TO_STEAL
            }

            // NOTE(review): staleness is computed as a difference of two readings, which stays correct across
            // Long overflow of a System.nanoTime-style clock as long as the real elapsed time fits in a Long.
            // This assumes schedulerTimeSource behaves like System.nanoTime — confirm.
            val time = schedulerTimeSource.nanoTime()
            val staleness = time - lastScheduled.submissionTime
            if (staleness < WORK_STEALING_TIME_RESOLUTION_NS) {
                return WORK_STEALING_TIME_RESOLUTION_NS - staleness
            }

            /*
             * If CAS has failed, either someone else has stolen this task, or the owner executed this task
             * and dispatched another one. In the latter case we loop around and retry, to avoid missing a task.
             */
            if (lastScheduledTask.compareAndSet(lastScheduled, null)) {
                stolenTaskRef.element = lastScheduled
                return TASK_STOLEN
            }
        }
    }

    /** Moves the head of the buffer to [queue]; returns `false` if the buffer was empty. */
    private fun pollTo(queue: GlobalQueue): Boolean {
        val task = pollBuffer() ?: return false
        queue.addLast(task)
        return true
    }

    /**
     * Removes and returns the task at the head (consumer side) of the buffer, or `null` if the buffer is empty.
     * The consumer index is claimed via CAS, so this may be called by the owner and by stealers concurrently.
     */
    private fun pollBuffer(): Task? {
        while (true) {
            val tailLocal = consumerIndex.value
            if (tailLocal - producerIndex.value == 0) return null
            val index = tailLocal and MASK
            if (consumerIndex.compareAndSet(tailLocal, tailLocal + 1)) {
                // Nulls are allowed when blocking tasks are stolen from the middle of the queue.
                val value = buffer.getAndSet(index, null) ?: continue
                value.decrementIfBlocking()
                return value
            }
        }
    }

    /** Decrements [blockingTasksInBuffer] if this task is blocking, asserting that the counter never goes negative. */
    private fun Task?.decrementIfBlocking() {
        if (this != null && isBlocking) {
            val value = blockingTasksInBuffer.decrementAndGet()
            assert { value >= 0 }
        }
    }
}