1 /*
<lambda>null2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.systemui.communal.data.db
18 
19 import android.content.ComponentName
20 import androidx.room.Room
21 import androidx.test.ext.junit.runners.AndroidJUnit4
22 import androidx.test.filters.SmallTest
23 import com.android.systemui.SysuiTestCase
24 import com.android.systemui.communal.nano.CommunalHubState
25 import com.android.systemui.communal.shared.model.SpanValue
26 import com.android.systemui.coroutines.collectLastValue
27 import com.android.systemui.lifecycle.InstantTaskExecutorRule
28 import com.google.common.truth.Truth.assertThat
29 import java.io.IOException
30 import kotlinx.coroutines.test.StandardTestDispatcher
31 import kotlinx.coroutines.test.TestScope
32 import kotlinx.coroutines.test.runTest
33 import org.junit.After
34 import org.junit.Before
35 import org.junit.Rule
36 import org.junit.Test
37 import org.junit.runner.RunWith
38 import org.mockito.MockitoAnnotations
39 
/**
 * Tests for [CommunalWidgetDao], run against an in-memory [CommunalDatabase].
 *
 * A fresh database is built before every test and closed afterwards, so the auto-generated uids
 * asserted below start at 1 in each test.
 */
@SmallTest
@RunWith(AndroidJUnit4::class)
class CommunalWidgetDaoTest : SysuiTestCase() {
    @JvmField @Rule val instantTaskExecutor = InstantTaskExecutorRule()

    private lateinit var db: CommunalDatabase
    private lateinit var communalWidgetDao: CommunalWidgetDao

    private val testDispatcher = StandardTestDispatcher()
    private val testScope = TestScope(testDispatcher)

    /** Builds the in-memory database and grabs the DAO under test. */
    @Before
    @Throws(IOException::class)
    fun setUp() {
        MockitoAnnotations.initMocks(this)
        db =
            Room.inMemoryDatabaseBuilder(context, CommunalDatabase::class.java)
                .allowMainThreadQueries()
                .build()
        communalWidgetDao = db.communalWidgetDao()
    }

    /** Closes the database created in [setUp]. */
    @After
    @Throws(IOException::class)
    fun teardown() {
        db.close()
    }

    /** A widget that was added can be read back by its widget id. */
    @Test
    fun addWidget_readValueInDb() =
        testScope.runTest {
            addWidget(widgetInfo1)

            val entry = communalWidgetDao.getWidgetByIdNow(id = 1)
            assertThat(entry).isEqualTo(communalWidgetItemEntry1)
        }

    /** Deleting a widget id that was never added reports `false`. */
    @Test
    fun deleteWidget_notInDb_returnsFalse() =
        testScope.runTest {
            addWidget(widgetInfo1)

            assertThat(communalWidgetDao.deleteWidgetById(widgetId = 123)).isFalse()
        }

    /** Widgets added with explicit ranks are emitted by the widgets flow. */
    @Test
    fun addWidget_emitsActiveWidgetsInDb() =
        testScope.runTest {
            val widgets = collectLastValue(communalWidgetDao.getWidgets())

            listOf(widgetInfo1, widgetInfo2).forEach { addWidget(it) }

            assertThat(widgets())
                .containsExactly(
                    communalItemRankEntry1,
                    communalWidgetItemEntry1,
                    communalItemRankEntry2,
                    communalWidgetItemEntry2,
                )
        }

    /** Widgets added without a rank receive consecutive ranks in insertion order. */
    @Test
    fun addWidget_rankNotSpecified_widgetAddedAtTheEnd() =
        testScope.runTest {
            val widgets by collectLastValue(communalWidgetDao.getWidgets())

            // Verify database is empty
            assertThat(widgets).isEmpty()

            // Add widgets one by one without specifying rank. The DAO is called directly here
            // because the addWidget() helper always forwards a rank.
            listOf(widgetInfo1, widgetInfo2, widgetInfo3).forEach { metadata ->
                communalWidgetDao.addWidget(
                    widgetId = metadata.widgetId,
                    provider = metadata.provider,
                    userSerialNumber = metadata.userSerialNumber,
                    spanY = metadata.spanY,
                )
            }

            // Verify each new widget is added at the end
            assertThat(widgets)
                .containsExactly(
                    communalItemRankEntry1,
                    communalWidgetItemEntry1,
                    communalItemRankEntry2,
                    communalWidgetItemEntry2,
                    communalItemRankEntry3,
                    communalWidgetItemEntry3,
                )
        }

    /** After a widget is deleted, the widgets flow emits only the remaining widget. */
    @Test
    fun deleteWidget_emitsActiveWidgetsInDb() =
        testScope.runTest {
            val widgets = collectLastValue(communalWidgetDao.getWidgets())

            listOf(widgetInfo1, widgetInfo2).forEach { addWidget(it) }
            assertThat(widgets())
                .containsExactly(
                    communalItemRankEntry1,
                    communalWidgetItemEntry1,
                    communalItemRankEntry2,
                    communalWidgetItemEntry2,
                )

            communalWidgetDao.deleteWidgetById(communalWidgetItemEntry1.widgetId)
            assertThat(widgets()).containsExactly(communalItemRankEntry2, communalWidgetItemEntry2)
        }

    /** Swapping the ranks of two widgets swaps the order in which they are emitted. */
    @Test
    fun reorderWidget_emitsWidgetsInNewOrder() =
        testScope.runTest {
            val widgets = collectLastValue(communalWidgetDao.getWidgets())

            listOf(widgetInfo1, widgetInfo2).forEach { addWidget(it) }
            assertThat(widgets())
                .containsExactly(
                    communalItemRankEntry1,
                    communalWidgetItemEntry1,
                    communalItemRankEntry2,
                    communalWidgetItemEntry2,
                )
                .inOrder()

            // swapped ranks
            val widgetIdsToRankMap = mapOf(widgetInfo1.widgetId to 1, widgetInfo2.widgetId to 0)
            communalWidgetDao.updateWidgetOrder(widgetIdsToRankMap)
            assertThat(widgets())
                .containsExactly(
                    communalItemRankEntry2.copy(rank = 0),
                    communalWidgetItemEntry2,
                    communalItemRankEntry1.copy(rank = 1),
                    communalWidgetItemEntry1,
                )
                .inOrder()
        }

    /**
     * Adding a widget at a rank that is already taken inserts it there and pushes the widgets at
     * that rank and after it back by one.
     */
    @Test
    fun addNewWidgetWithReorder_emitsWidgetsInNewOrder() =
        testScope.runTest {
            val widgets = collectLastValue(communalWidgetDao.getWidgets())

            listOf(widgetInfo1, widgetInfo2, widgetInfo3).forEach { addWidget(it) }
            assertThat(widgets())
                .containsExactly(
                    communalItemRankEntry1,
                    communalWidgetItemEntry1,
                    communalItemRankEntry2,
                    communalWidgetItemEntry2,
                    communalItemRankEntry3,
                    communalWidgetItemEntry3,
                )
                .inOrder()

            // add a new widget at rank 1.
            communalWidgetDao.addWidget(
                widgetId = 4,
                provider = ComponentName("pk_name", "cls_name_4"),
                rank = 1,
                userSerialNumber = 0,
                spanY = SpanValue.Responsive(1),
            )

            val newRankEntry = CommunalItemRank(uid = 4L, rank = 1)
            val newWidgetEntry =
                CommunalWidgetItem(
                    uid = 4L,
                    widgetId = 4,
                    componentName = "pk_name/cls_name_4",
                    itemId = 4L,
                    userSerialNumber = 0,
                    spanY = 3,
                    spanYNew = 1,
                )
            // Widget 1 keeps rank 0; widgets 2 and 3 move from ranks 1 and 2 to ranks 2 and 3.
            assertThat(widgets())
                .containsExactly(
                    communalItemRankEntry1.copy(rank = 0),
                    communalWidgetItemEntry1,
                    newRankEntry,
                    newWidgetEntry,
                    communalItemRankEntry2.copy(rank = 2),
                    communalWidgetItemEntry2,
                    communalItemRankEntry3.copy(rank = 3),
                    communalWidgetItemEntry3,
                )
                .inOrder()
        }

    /**
     * Each [SpanValue] passed to the DAO is stored as the expected `spanY` / `spanYNew` pair:
     * `Responsive(1)` as 3 / 1, `Responsive(2)` as 6 / 2 and `Fixed(3)` as 3 / 1.
     */
    @Test
    fun addWidget_withDifferentSpanY_readsCorrectValuesInDb() =
        testScope.runTest {
            val widgets = collectLastValue(communalWidgetDao.getWidgets())

            // Add widgets with different spanY values
            communalWidgetDao.addWidget(
                widgetId = 1,
                provider = ComponentName("pkg_name", "cls_name_1"),
                rank = 0,
                userSerialNumber = 0,
                spanY = SpanValue.Responsive(1),
            )
            communalWidgetDao.addWidget(
                widgetId = 2,
                provider = ComponentName("pkg_name", "cls_name_2"),
                rank = 1,
                userSerialNumber = 0,
                spanY = SpanValue.Responsive(2),
            )
            communalWidgetDao.addWidget(
                widgetId = 3,
                provider = ComponentName("pkg_name", "cls_name_3"),
                rank = 2,
                userSerialNumber = 0,
                spanY = SpanValue.Fixed(3),
            )

            // Verify that the widgets have the correct spanY values
            assertThat(widgets())
                .containsExactly(
                    CommunalItemRank(uid = 1L, rank = 0),
                    CommunalWidgetItem(
                        uid = 1L,
                        widgetId = 1,
                        componentName = "pkg_name/cls_name_1",
                        itemId = 1L,
                        userSerialNumber = 0,
                        spanY = 3,
                        spanYNew = 1,
                    ),
                    CommunalItemRank(uid = 2L, rank = 1),
                    CommunalWidgetItem(
                        uid = 2L,
                        widgetId = 2,
                        componentName = "pkg_name/cls_name_2",
                        itemId = 2L,
                        userSerialNumber = 0,
                        spanY = 6,
                        spanYNew = 2,
                    ),
                    CommunalItemRank(uid = 3L, rank = 2),
                    CommunalWidgetItem(
                        uid = 3L,
                        widgetId = 3,
                        componentName = "pkg_name/cls_name_3",
                        itemId = 3L,
                        userSerialNumber = 0,
                        spanY = 3,
                        spanYNew = 1,
                    ),
                )
                .inOrder()
        }

    /** Restoring a hub state replaces the existing widgets with the widgets from that state. */
    @Test
    fun restoreCommunalHubState() =
        testScope.runTest {
            // Set up db
            listOf(widgetInfo1, widgetInfo2, widgetInfo3).forEach { addWidget(it) }

            // Restore db to fake state
            communalWidgetDao.restoreCommunalHubState(fakeState)

            // Verify db matches new state
            val expected =
                fakeState.widgets
                    .mapIndexed { index, fakeWidget ->
                        // Auto-generated uid continues after the initial 3 widgets and starts
                        // at 4
                        val uid = index + 4L
                        val rank = CommunalItemRank(uid = uid, rank = fakeWidget.rank)
                        val widget =
                            CommunalWidgetItem(
                                uid = uid,
                                widgetId = fakeWidget.widgetId,
                                componentName = fakeWidget.componentName,
                                itemId = rank.uid,
                                userSerialNumber = fakeWidget.userSerialNumber,
                                // Each fake widget sets only one of the two span fields; the
                                // other one is expected to come back as at least 3 (spanY) or
                                // at least 1 (spanYNew).
                                spanY = fakeWidget.spanY.coerceAtLeast(3),
                                spanYNew = fakeWidget.spanYNew.coerceAtLeast(1),
                            )
                        rank to widget
                    }
                    .toMap()
            val widgets by collectLastValue(communalWidgetDao.getWidgets())
            assertThat(widgets).containsExactlyEntriesIn(expected)
        }

    /**
     * Adds the widget described by [metadata] through the DAO.
     *
     * @param metadata the widget fields to insert
     * @param rank rank to insert at; when `null`, the rank stored in [metadata] is used
     */
    private fun addWidget(metadata: FakeWidgetMetadata, rank: Int? = null) {
        communalWidgetDao.addWidget(
            widgetId = metadata.widgetId,
            provider = metadata.provider,
            rank = rank ?: metadata.rank,
            userSerialNumber = metadata.userSerialNumber,
            spanY = metadata.spanY,
        )
    }

    /** Arguments for a single `addWidget` call on the DAO, bundled for reuse across tests. */
    data class FakeWidgetMetadata(
        val widgetId: Int,
        val provider: ComponentName,
        val rank: Int,
        val userSerialNumber: Int,
        val spanY: SpanValue,
    )

    companion object {
        // Three widgets with consecutive widget ids (1..3) and ranks (0..2).
        val widgetInfo1 =
            FakeWidgetMetadata(
                widgetId = 1,
                provider = ComponentName("pk_name", "cls_name_1"),
                rank = 0,
                userSerialNumber = 0,
                spanY = SpanValue.Responsive(1),
            )
        val widgetInfo2 =
            FakeWidgetMetadata(
                widgetId = 2,
                provider = ComponentName("pk_name", "cls_name_2"),
                rank = 1,
                userSerialNumber = 0,
                spanY = SpanValue.Responsive(1),
            )
        val widgetInfo3 =
            FakeWidgetMetadata(
                widgetId = 3,
                provider = ComponentName("pk_name", "cls_name_3"),
                rank = 2,
                userSerialNumber = 10,
                spanY = SpanValue.Responsive(1),
            )

        // Rows the tests expect in the database once widgetInfo1..3 have been added, in that
        // order, to an empty database.
        val communalItemRankEntry1 = CommunalItemRank(uid = 1L, rank = widgetInfo1.rank)
        val communalItemRankEntry2 = CommunalItemRank(uid = 2L, rank = widgetInfo2.rank)
        val communalItemRankEntry3 = CommunalItemRank(uid = 3L, rank = widgetInfo3.rank)
        val communalWidgetItemEntry1 =
            CommunalWidgetItem(
                uid = 1L,
                widgetId = widgetInfo1.widgetId,
                componentName = widgetInfo1.provider.flattenToString(),
                itemId = communalItemRankEntry1.uid,
                userSerialNumber = widgetInfo1.userSerialNumber,
                spanY = 3,
                spanYNew = 1,
            )
        val communalWidgetItemEntry2 =
            CommunalWidgetItem(
                uid = 2L,
                widgetId = widgetInfo2.widgetId,
                componentName = widgetInfo2.provider.flattenToString(),
                itemId = communalItemRankEntry2.uid,
                userSerialNumber = widgetInfo2.userSerialNumber,
                spanY = 3,
                spanYNew = 1,
            )
        val communalWidgetItemEntry3 =
            CommunalWidgetItem(
                uid = 3L,
                widgetId = widgetInfo3.widgetId,
                componentName = widgetInfo3.provider.flattenToString(),
                itemId = communalItemRankEntry3.uid,
                userSerialNumber = widgetInfo3.userSerialNumber,
                spanY = 3,
                spanYNew = 1,
            )

        // Hub state used by restoreCommunalHubState(). The first widget sets only spanY and the
        // second sets only spanYNew.
        val fakeState =
            CommunalHubState().apply {
                widgets =
                    listOf(
                            CommunalHubState.CommunalWidgetItem().apply {
                                widgetId = 1
                                componentName = "pk_name/fake_widget_1"
                                rank = 1
                                userSerialNumber = 0
                                spanY = 3
                            },
                            CommunalHubState.CommunalWidgetItem().apply {
                                widgetId = 2
                                componentName = "pk_name/fake_widget_2"
                                rank = 2
                                userSerialNumber = 10
                                spanYNew = 1
                            },
                        )
                        .toTypedArray()
            }
    }
}
475