1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.testutils
18 
19 import android.content.Context
20 import android.content.pm.PackageManager
21 import android.net.ConnectivityManager
22 import android.net.Network
23 import android.net.NetworkCapabilities
24 import com.android.internal.annotations.VisibleForTesting
25 import com.android.net.module.util.BitUtils
26 import java.util.concurrent.CompletableFuture
27 import java.util.concurrent.TimeUnit
28 import org.junit.runner.Description
29 import org.junit.runner.notification.Failure
30 import org.junit.runner.notification.RunListener
31 import org.junit.runner.notification.RunNotifier
32 
/**
 * Detects tests that do not restore the default network they started with.
 *
 * [init] records the transport types of the current default network and registers [listener]
 * on [notifier]. After every finished test, the listener waits up to [timeoutMs] for a
 * validated default network whose transport types match the recorded ones. The first test for
 * which that does not happen is recorded in [firstFailure]; later tests are not checked.
 * [reportResultAndCleanUp] reports the recorded failure (if any) under a caller-supplied
 * [Description] and detaches the listener.
 *
 * @param ctx context used to obtain [ConnectivityManager] and [PackageManager].
 * @param notifier JUnit notifier that [listener] is attached to and results are reported on.
 * @param timeoutMs how long to wait after each test for the default network to be restored.
 */
@VisibleForTesting(visibility = VisibleForTesting.Visibility.PRIVATE)
class DefaultNetworkRestoreMonitor(
        ctx: Context,
        private val notifier: RunNotifier,
        private val timeoutMs: Long = 30_000
) {
    /** First problem detected (setup failure or an unrestored default network), or null. */
    var firstFailure: Exception? = null

    /** Transport types of the initial default network, packed with [BitUtils.packBits]. */
    var initialTransports = 0L

    val cm = checkNotNull(ctx.getSystemService(ConnectivityManager::class.java)) {
        "ConnectivityManager is not available"
    }
    val pm = ctx.packageManager

    /** Checks, after each finished test, that the initial default network is back. */
    val listener = object : RunListener() {
        override fun testFinished(desc: Description) {
            // Only the first method that does not restore the default network should be blamed.
            if (firstFailure != null) {
                return
            }
            val cb = TestableNetworkCallback()
            cm.registerDefaultNetworkCallback(cb)
            try {
                // Presumably throws AssertionError if no matching callback arrives within
                // timeoutMs (that is what the catch below relies on) — confirm in
                // TestableNetworkCallback.
                cb.eventuallyExpect<RecorderCallback.CallbackEntry.CapabilitiesChanged>(
                    timeoutMs = timeoutMs
                ) {
                    BitUtils.packBits(it.caps.transportTypes) == initialTransports &&
                            it.caps.hasCapability(NetworkCapabilities.NET_CAPABILITY_VALIDATED)
                }
            } catch (e: AssertionError) {
                // methodName is null for descriptions that do not denote a single test method.
                val testName = desc.methodName ?: desc.displayName
                firstFailure = IllegalStateException("$testName does not restore the " +
                        "default network, initialTransports = $initialTransports", e)
            } finally {
                cm.unregisterNetworkCallback(cb)
            }
        }
    }

    /**
     * Prepares the device and starts monitoring.
     *
     * Ensures Wi-Fi and/or cellular are validated (when the device has the feature). It then
     * records the transports of the current default network into [initialTransports] and
     * attaches [listener] to [notifier]. If the default network's capabilities cannot be read
     * in time, the error is stored in [firstFailure] instead of being thrown.
     */
    fun init(connectUtil: ConnectUtil) {
        // Ensure Wi-Fi and cellular connection before running test to avoid starting test
        // with unexpected default network.
        // ConnectivityTestTargetPreparer does the same thing, but it's possible that previous tests
        // don't enable DefaultNetworkRestoreMonitor and the default network is not restored.
        // This can be removed if all tests enable DefaultNetworkRestoreMonitor
        if (pm.hasSystemFeature(PackageManager.FEATURE_WIFI)) {
            connectUtil.ensureWifiValidated()
        }
        if (pm.hasSystemFeature(PackageManager.FEATURE_TELEPHONY)) {
            connectUtil.ensureCellularValidated()
        }

        val capFuture = CompletableFuture<NetworkCapabilities>()
        val cb = object : ConnectivityManager.NetworkCallback() {
            override fun onCapabilitiesChanged(network: Network, cap: NetworkCapabilities) {
                // Only the first report is used; later complete() calls are no-ops.
                capFuture.complete(cap)
            }
        }
        cm.registerDefaultNetworkCallback(cb)
        try {
            val cap = capFuture.get(INITIAL_CAPS_TIMEOUT_MS, TimeUnit.MILLISECONDS)
            initialTransports = BitUtils.packBits(cap.transportTypes)
        } catch (e: Exception) {
            // Deliberately not thrown: the failure is surfaced by reportResultAndCleanUp.
            firstFailure = IllegalStateException(
                    "Failed to get default network status before starting tests", e
            )
        } finally {
            cm.unregisterNetworkCallback(cb)
        }
        notifier.addListener(listener)
    }

    /**
     * Reports the monitoring result as a synthetic test named by [desc] and stops monitoring.
     *
     * The synthetic test fails with [firstFailure] if one was recorded, otherwise it passes.
     */
    fun reportResultAndCleanUp(desc: Description) {
        // Detach first: otherwise finishing the synthetic test below would run the listener's
        // default network check for it, possibly blocking for timeoutMs, and any failure found
        // then could no longer be reported.
        notifier.removeListener(listener)
        notifier.fireTestStarted(desc)
        firstFailure?.let { failure -> notifier.fireTestFailure(Failure(desc, failure)) }
        notifier.fireTestFinished(desc)
    }

    private companion object {
        /** How long [init] waits for the first capabilities of the default network. */
        const val INITIAL_CAPS_TIMEOUT_MS = 10_000L
    }
}
114