1 package kotlinx.coroutines.internal
2 
3 import kotlinx.atomicfu.*
4 import kotlinx.coroutines.*
5 
/**
 * An element that can be stored in a [ThreadSafeHeap].
 *
 * Both properties are bookkeeping owned by the heap, not by the implementor: the heap sets [heap] to itself
 * when the node is added and back to `null` when it is removed, and keeps [index] equal to the node's
 * current slot in its backing array (resetting it to `-1` on removal).
 *
 * @suppress **This is an internal API and should not be used from general code.**
 */
@InternalCoroutinesApi
public interface ThreadSafeHeapNode {
    /** The heap this node is currently stored in, or `null` when it is not in any heap. */
    public var heap: ThreadSafeHeap<*>?
    /** Slot of this node in the owning heap's backing array; `-1` after the node has been removed. */
    public var index: Int
}
14 
/**
 * Synchronized binary min-heap of [ThreadSafeHeapNode]s, ordered by their [Comparable] implementation
 * (a parent is always `<=` its children, so the smallest node is at the root).
 *
 * Every query and mutation runs under `synchronized(this)`. The exceptions are [size] and [isEmpty], which read an
 * atomic counter and so can be read without the lock.
 *
 * Each stored node records its owning heap in [ThreadSafeHeapNode.heap] and its array slot in
 * [ThreadSafeHeapNode.index]. This lets [remove] locate a node directly instead of searching for it.
 *
 * @suppress **This is an internal API and should not be used from general code.**
 */
@InternalCoroutinesApi
public open class ThreadSafeHeap<T> : SynchronizedObject() where T: ThreadSafeHeapNode, T: Comparable<T> {
    // Backing array of the implicit binary tree: the children of slot `i` are at `2 * i + 1` and `2 * i + 2`.
    // Slots `[0, size)` are occupied; every slot at or beyond `size` is null. Allocated lazily on first insert.
    private var a: Array<T?>? = null

    private val _size = atomic(0)

    /** Number of nodes currently in the heap. Backed by an atomic, so it can be read without the lock. */
    public var size: Int
        get() = _size.value
        private set(value) { _size.value = value }

    /** `true` when the heap holds no nodes. Like [size], it can be read without the lock. */
    public val isEmpty: Boolean get() = size == 0

    /**
     * Returns the first node in backing-array order (not priority order) that satisfies [predicate],
     * or `null` if none does. [predicate] is invoked while the heap's lock is held.
     */
    public fun find(
        predicate: (value: T) -> Boolean
    ): T? = synchronized(this) block@{
        for (i in 0 until size) {
            // Non-null: every slot below `size` is occupied.
            val value = a?.get(i)!!
            if (predicate(value)) return@block value
        }
        null
    }

    /** Returns the smallest node without removing it, or `null` when the heap is empty. */
    public fun peek(): T? = synchronized(this) { firstImpl() }

    /** Removes and returns the smallest node, or returns `null` when the heap is empty. */
    public fun removeFirstOrNull(): T? = synchronized(this) {
        if (size > 0) removeAtImpl(0) else null
    }

    /**
     * Removes and returns the smallest node if [predicate] accepts it.
     *
     * Returns `null` and leaves the heap unchanged when the heap is empty or [predicate] rejects the node.
     * The check and the removal happen atomically, under the heap's lock.
     */
    public inline fun removeFirstIf(predicate: (T) -> Boolean): T? = synchronized(this) {
        val first = firstImpl() ?: return null
        if (predicate(first)) removeAtImpl(0) else null
    }

    /** Inserts [node] into the heap. [node] must not currently belong to any heap. */
    public fun addLast(node: T): Unit = synchronized(this) { addImpl(node) }

    /**
     * Atomically inserts [node] if [cond] holds, and returns whether it was inserted.
     *
     * The condition also receives the current first (smallest) node in the heap, or `null` when the heap is empty.
     */
    public inline fun addLastIf(node: T, cond: (T?) -> Boolean): Boolean = synchronized(this) {
        if (cond(firstImpl())) {
            addImpl(node)
            true
        } else {
            false
        }
    }

    /**
     * Removes [node] from this heap.
     *
     * @return `true` if the node was removed. Returns `false`, and changes nothing, when [node] is not stored in
     * *this* heap: either it is in no heap, or it belongs to a different heap.
     */
    public fun remove(node: T): Boolean = synchronized(this) {
        // The node's index only addresses the array of the heap that owns it. Trusting it for a node owned by
        // another heap would evict an unrelated element from this one.
        if (node.heap !== this) {
            false
        } else {
            val index = node.index
            assert { index >= 0 }
            removeAtImpl(index)
            true
        }
    }

    /** Root of the heap, or `null` when empty (slot 0 is null then). Unsynchronized; call under the lock. */
    @PublishedApi
    internal fun firstImpl(): T? = a?.get(0)

    /**
     * Removes and returns the node at [index], restoring the heap order.
     *
     * The node is detached: its heap is set to `null` and its index to `-1`.
     * Unsynchronized; callers in this class invoke it while holding the lock.
     */
    @PublishedApi
    internal fun removeAtImpl(index: Int): T {
        assert { size > 0 }
        val a = this.a!!
        size--
        if (index < size) {
            // Move the victim to the last occupied slot. The element that took its place may then violate the
            // heap order in either direction, so sift it up if it is smaller than its parent, otherwise down.
            swap(index, size)
            val j = (index - 1) / 2
            if (index > 0 && a[index]!! < a[j]!!) {
                swap(index, j)
                siftUpFrom(j)
            } else {
                siftDownFrom(index)
            }
        }
        val result = a[size]!!
        assert { result.heap === this }
        result.heap = null
        result.index = -1
        a[size] = null // keep the "slots at or beyond size are null" invariant
        return result
    }

    /**
     * Appends [node] and sifts it up into position, growing the backing array if needed.
     * Unsynchronized; callers in this class invoke it while holding the lock.
     */
    @PublishedApi
    internal fun addImpl(node: T) {
        assert { node.heap == null }
        node.heap = this
        val a = realloc()
        val i = size++
        a[i] = node
        node.index = i
        siftUpFrom(i)
    }

    /** Moves the node at [i] towards the root while it is smaller than its parent. */
    private tailrec fun siftUpFrom(i: Int) {
        if (i <= 0) return
        val a = a!!
        val j = (i - 1) / 2
        if (a[j]!! <= a[i]!!) return
        swap(i, j)
        siftUpFrom(j)
    }

    /** Moves the node at [i] towards the leaves while it is larger than its smaller child. */
    private tailrec fun siftDownFrom(i: Int) {
        var j = 2 * i + 1
        if (j >= size) return
        val a = a!!
        if (j + 1 < size && a[j + 1]!! < a[j]!!) j++
        if (a[i]!! <= a[j]!!) return
        swap(i, j)
        siftDownFrom(j)
    }

    /**
     * Returns a backing array with room for at least one more node.
     * Allocates a 4-slot array on first use and doubles the capacity when full.
     */
    @Suppress("UNCHECKED_CAST")
    private fun realloc(): Array<T?> {
        val a = this.a
        return when {
            a == null -> (arrayOfNulls<ThreadSafeHeapNode>(4) as Array<T?>).also { this.a = it }
            size >= a.size -> a.copyOf(size * 2).also { this.a = it }
            else -> a
        }
    }

    /** Exchanges the nodes in slots [i] and [j] and updates their recorded indices. */
    private fun swap(i: Int, j: Int) {
        val a = a!!
        val ni = a[j]!!
        val nj = a[i]!!
        a[i] = ni
        a[j] = nj
        ni.index = i
        nj.index = j
    }
}
159