1 /**
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.android.net.module.util
17 
18 import android.os.Message
19 import androidx.test.ext.junit.runners.AndroidJUnit4
20 import androidx.test.filters.SmallTest
21 import com.android.internal.util.State
22 import com.android.net.module.util.SyncStateMachine.StateInfo
23 import java.util.ArrayDeque
24 import java.util.ArrayList
25 import kotlin.test.assertFailsWith
26 import org.junit.Assert.assertEquals
27 import org.junit.Assert.assertTrue
28 import org.junit.Test
29 import org.junit.runner.RunWith
30 import org.mockito.ArgumentMatchers.any
31 import org.mockito.Mockito.inOrder
32 import org.mockito.Mockito.spy
33 import org.mockito.Mockito.verifyNoMoreInteractions
34 
// Message codes. A TestState constructed with MSG_n handles only messages whose `what` is MSG_n
// and leaves every other message unhandled.
private const val MSG_1 = 1
private const val MSG_2 = 2
private const val MSG_3 = 3
private const val MSG_4 = 4
private const val MSG_5 = 5
private const val MSG_6 = 6
private const val MSG_7 = 7

// Values sent as Message.arg1 / Message.arg2 with every test message; TestState asserts on them.
private const val ARG_1 = 100
private const val ARG_2 = 200
45 
/**
 * Tests for [SyncStateMachine].
 *
 * Each test builds a hierarchy of spied [TestState]s, starts the state machine, feeds it messages
 * through [processMessage] and then verifies, in order, which states were offered each message and
 * which `enter()`/`exit()` calls the resulting transition produced.
 */
@RunWith(AndroidJUnit4::class)
@SmallTest
class SyncStateMachineTest {
    // Every state is a Mockito spy so that enter()/exit()/processMessage() calls can be verified
    // in order through mInOrder.
    private val mState1 = spy(object : TestState(MSG_1) {})
    private val mState2 = spy(object : TestState(MSG_2) {})
    private val mState3 = spy(object : TestState(MSG_3) {})
    private val mState4 = spy(object : TestState(MSG_4) {})
    private val mState5 = spy(object : TestState(MSG_5) {})
    private val mState6 = spy(object : TestState(MSG_6) {})
    private val mState7 = spy(object : TestState(MSG_7) {})
    private val mInOrder = inOrder(mState1, mState2, mState3, mState4, mState5, mState6, mState7)
    // Lazily initialized so that Thread.currentThread() below is the thread running the test,
    // not the thread that constructs this test class.
    private val mSM by lazy {
        SyncStateMachine("TestSyncStateMachine", Thread.currentThread(), true /* debug */)
    }
    // States handed to mSM.addAllStates(); filled through addState().
    private val mAllStates = ArrayList<StateInfo>()

    // (state, msg.what) pairs recorded by TestState.processMessage(), in call order. Only `what`
    // is kept because SyncStateMachine reuses the Message instance once it has been processed.
    private val mMsgProcessedResults = ArrayDeque<Pair<State, Int>>()

    /**
     * A state that records every message it is offered.
     *
     * It handles (returns true for) only messages whose `what` equals [expected]; when such a
     * message carries a [State] in its `obj` field, it transitions the state machine there. Any
     * other message is left unhandled (returns false), which the tests below expect to be offered
     * to the parent state next.
     */
    open inner class TestState(val expected: Int) : State() {
        override fun processMessage(msg: Message): Boolean {
            mMsgProcessedResults.add(this to msg.what)
            assertEquals(ARG_1, msg.arg1)
            assertEquals(ARG_2, msg.arg2)

            if (msg.what == expected) {
                // The destination state, if any, is passed in the obj field.
                msg.obj?.let { mSM.transitionTo(it as State) }
                return true
            }

            return false
        }
    }

    /** Fails if any test state has a Mockito interaction that has not been verified yet. */
    private fun verifyNoMoreInteractions() {
        // Must list every spied state (including mState7) or unexpected calls on it go unnoticed.
        verifyNoMoreInteractions(mState1, mState2, mState3, mState4, mState5, mState6, mState7)
    }

    /**
     * Sends message [what] with [ARG_1]/[ARG_2] to the state machine. [toState] is delivered as
     * the message object, so the state that handles the message transitions to it when non-null.
     */
    private fun processMessage(what: Int, toState: State?) {
        mSM.processMessage(what, ARG_1, ARG_2, toState)
    }

    /**
     * Verifies that message [what] was offered to exactly [processedStates], in that order, and
     * that no further processMessage() call was recorded.
     */
    private fun verifyMessageProcessedBy(what: Int, vararg processedStates: State) {
        for (state in processedStates) {
            // InOrder.verify can't check the Message content here because SyncSM will recycle the
            // message after it's been processed. SyncSM reuses the same Message instance for all
            // messages it processes. So, if using InOrder.verify to verify the content of a message
            // after SyncSM has processed it, the content would be wrong.
            mInOrder.verify(state).processMessage(any())
            val (processedState, msgWhat) = mMsgProcessedResults.remove()
            assertEquals(state, processedState)
            assertEquals(what, msgWhat)
        }
        assertTrue(
            "Unexpected processed messages: $mMsgProcessedResults",
            mMsgProcessedResults.isEmpty()
        )
    }

    /** Registers [state] under [parent] (null for a root state); applied by the next addAllStates. */
    private fun addState(state: State, parent: State? = null) {
        mAllStates.add(StateInfo(state, parent))
    }

    /** Registers all added states, starts in mState1 and verifies only mState1 was entered. */
    private fun verifyStart() {
        mSM.addAllStates(mAllStates)
        mSM.start(mState1)
        mInOrder.verify(mState1).enter()
        verifyNoMoreInteractions()
    }

    @Test
    fun testInitialState() {
        // mState1 -> initial
        //    |
        // mState2
        addState(mState1)
        addState(mState2, mState1)
        mSM.addAllStates(mAllStates)

        mSM.start(mState1)
        mInOrder.verify(mState1).enter()
        verifyNoMoreInteractions()
    }

    @Test
    fun testStartFromLeafState() {
        // mState1 -> initial
        //    |
        // mState2
        //    |
        // mState3
        addState(mState1)
        addState(mState2, mState1)
        addState(mState3, mState2)
        mSM.addAllStates(mAllStates)

        // Starting from a leaf enters every ancestor first, from the root down.
        mSM.start(mState3)
        mInOrder.verify(mState1).enter()
        mInOrder.verify(mState2).enter()
        mInOrder.verify(mState3).enter()
        verifyNoMoreInteractions()
    }

    @Test
    fun testAddState() {
        // Adding the same state twice must be rejected.
        addState(mState1)
        addState(mState1)
        assertFailsWith<IllegalStateException> {
            mSM.addAllStates(mAllStates)
        }
    }

    @Test
    fun testProcessMessage() {
        // mState1
        //    |
        // mState2
        addState(mState1)
        addState(mState2, mState1)
        verifyStart()

        processMessage(MSG_1, null)
        verifyMessageProcessedBy(MSG_1, mState1)
        verifyNoMoreInteractions()
    }

    @Test
    fun testTwoStates() {
        // mState1 <-initial, mState2
        addState(mState1)
        addState(mState2)
        verifyStart()

        // Test transition to mState2
        processMessage(MSG_1, mState2)
        verifyMessageProcessedBy(MSG_1, mState1)
        mInOrder.verify(mState1).exit()
        mInOrder.verify(mState2).enter()
        verifyNoMoreInteractions()

        // Transitioning to mState2 while already in mState2 (the current state) must not cause
        // any exit()/enter() calls.
        processMessage(MSG_2, mState2)
        verifyMessageProcessedBy(MSG_2, mState2)
        verifyNoMoreInteractions()
    }

    @Test
    fun testTwoStateTrees() {
        //    mState1 -> initial  mState4
        //    /     \             /     \
        // mState2 mState3     mState5 mState6
        addState(mState1)
        addState(mState2, mState1)
        addState(mState3, mState1)
        addState(mState4)
        addState(mState5, mState4)
        addState(mState6, mState4)
        verifyStart()

        //    mState1 -> current     mState4
        //    /     \                /     \
        // mState2 mState3 -> dest mState5 mState6
        processMessage(MSG_1, mState3)
        verifyMessageProcessedBy(MSG_1, mState1)
        mInOrder.verify(mState3).enter()
        verifyNoMoreInteractions()

        //           mState1                     mState4
        //           /     \                     /     \
        // dest <- mState2 mState3 -> current mState5 mState6
        processMessage(MSG_1, mState2)
        verifyMessageProcessedBy(MSG_1, mState3, mState1)
        mInOrder.verify(mState3).exit()
        mInOrder.verify(mState2).enter()
        verifyNoMoreInteractions()

        //               mState1          mState4
        //               /     \          /     \
        // current <- mState2 mState3 mState5 mState6 -> dest
        // The trees share no ancestor, so the whole current branch exits before the other is entered.
        processMessage(MSG_2, mState6)
        verifyMessageProcessedBy(MSG_2, mState2)
        mInOrder.verify(mState2).exit()
        mInOrder.verify(mState1).exit()
        mInOrder.verify(mState4).enter()
        mInOrder.verify(mState6).enter()
        verifyNoMoreInteractions()
    }

    @Test
    fun testMultiDepthTransition() {
        //      mState1 -> current
        //    |          \
        //  mState2         mState6
        //    |   \           |
        //  mState3 mState5  mState7
        //    |
        //  mState4
        addState(mState1)
        addState(mState2, mState1)
        addState(mState6, mState1)
        addState(mState3, mState2)
        addState(mState5, mState2)
        addState(mState7, mState6)
        addState(mState4, mState3)
        verifyStart()

        //      mState1 -> current
        //    |          \
        //  mState2         mState6
        //    |   \           |
        //  mState3 mState5  mState7
        //    |
        //  mState4 -> dest
        processMessage(MSG_1, mState4)
        verifyMessageProcessedBy(MSG_1, mState1)
        mInOrder.verify(mState2).enter()
        mInOrder.verify(mState3).enter()
        mInOrder.verify(mState4).enter()
        verifyNoMoreInteractions()

        //            mState1
        //        /            \
        //  mState2             mState6
        //    |   \                 \
        //  mState3 mState5 -> dest  mState7
        //    |
        //  mState4 -> current
        processMessage(MSG_1, mState5)
        verifyMessageProcessedBy(MSG_1, mState4, mState3, mState2, mState1)
        mInOrder.verify(mState4).exit()
        mInOrder.verify(mState3).exit()
        mInOrder.verify(mState5).enter()
        verifyNoMoreInteractions()

        //            mState1
        //        /              \
        //  mState2               mState6
        //    |   \                    \
        //  mState3 mState5 -> current  mState7 -> dest
        //    |
        //  mState4
        processMessage(MSG_2, mState7)
        verifyMessageProcessedBy(MSG_2, mState5, mState2)
        mInOrder.verify(mState5).exit()
        mInOrder.verify(mState2).exit()
        mInOrder.verify(mState6).enter()
        mInOrder.verify(mState7).enter()
        verifyNoMoreInteractions()
    }
}
295