1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package com.google.snippet.wifi.aware;
17 
import android.content.Context;
import android.net.ConnectivityManager;
import android.net.Network;
import android.net.NetworkCapabilities;
import android.net.NetworkRequest;
import android.net.TransportInfo;
import android.net.wifi.aware.WifiAwareChannelInfo;
import android.net.wifi.aware.WifiAwareNetworkInfo;

import androidx.annotation.NonNull;
import androidx.test.core.app.ApplicationProvider;

import com.google.android.mobly.snippet.Snippet;
import com.google.android.mobly.snippet.event.EventCache;
import com.google.android.mobly.snippet.event.SnippetEvent;
import com.google.android.mobly.snippet.rpc.AsyncRpc;
import com.google.android.mobly.snippet.rpc.Rpc;
import com.google.android.mobly.snippet.util.Log;

import org.json.JSONException;

import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Inet6Address;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
48 import java.util.concurrent.ConcurrentHashMap;
49 
50 public class ConnectivityManagerSnippet implements Snippet {
51     private static final String EVENT_KEY_CB_NAME = "callbackName";
52     private static final String EVENT_KEY_NETWORK = "network";
53     private static final String EVENT_KEY_NETWORK_CAP = "networkCapabilities";
54     private static final String EVENT_KEY_TRANSPORT_INFO_CLASS = "transportInfoClassName";
55     private static final String EVENT_KEY_TRANSPORT_INFO_CHANNEL_IN_MHZ = "channelInMhz";
56     private static final int CLOSE_SOCKET_TIMEOUT = 15 * 1000;
57     private static final int ACCEPT_TIMEOUT = 30 * 1000;
58     private static final int SOCKET_SO_TIMEOUT = 30 * 1000;
59     private static final int TRANSPORT_PROTOCOL_TCP = 6;
60 
61     private final Context mContext;
62     private final ConnectivityManager mConnectivityManager;
63 
64     private final ConcurrentHashMap<String, ServerSocket> mServerSockets =
65             new ConcurrentHashMap<>();
66     private final ConcurrentHashMap<String, NetworkCallback> mNetworkCallBacks =
67             new ConcurrentHashMap<>();
68     private final ConcurrentHashMap<String, Socket> mSockets = new ConcurrentHashMap<>();
69     private final ConcurrentHashMap<String, OutputStream> mOutputStreams =
70             new ConcurrentHashMap<>();
71     private final ConcurrentHashMap<String, InputStream> mInputStreams = new ConcurrentHashMap<>();
72     private final ConcurrentHashMap<String, Thread> mSocketThreads = new ConcurrentHashMap<>();
73 
74     /**
75      * Custom exception class for handling specific errors related to the ConnectivityManagerSnippet
76      * operations.
77      */
78     class ConnectivityManagerSnippetException extends Exception {
ConnectivityManagerSnippetException(String msg)79         ConnectivityManagerSnippetException(String msg) {
80             super(msg);
81         }
82     }
83 
ConnectivityManagerSnippet()84     public ConnectivityManagerSnippet() throws ConnectivityManagerSnippetException {
85         mContext = ApplicationProvider.getApplicationContext();
86         mConnectivityManager = mContext.getSystemService(ConnectivityManager.class);
87         if (mConnectivityManager == null) {
88             throw new ConnectivityManagerSnippetException(
89                     "ConnectivityManager not " + "available.");
90         }
91     }
92 
93     public class NetworkCallback extends ConnectivityManager.NetworkCallback {
94 
95 
96         String mCallBackId;
97         Network mNetWork;
98         NetworkCapabilities mNetworkCapabilities;
99 
100 
NetworkCallback(String callBackId)101         NetworkCallback(String callBackId) {
102             mCallBackId = callBackId;
103         }
104 
105         @Override
onUnavailable()106         public void onUnavailable() {
107             SnippetEvent event = new SnippetEvent(mCallBackId, "NetworkCallback");
108             event.getData().putString(EVENT_KEY_CB_NAME, "onUnavailable");
109             EventCache.getInstance().postEvent(event);
110         }
111 
112         @Override
onCapabilitiesChanged(@onNull Network network, @NonNull NetworkCapabilities networkCapabilities)113         public void onCapabilitiesChanged(@NonNull Network network,
114                 @NonNull NetworkCapabilities networkCapabilities) {
115             SnippetEvent event = new SnippetEvent(mCallBackId, "NetworkCallback");
116             event.getData().putString(EVENT_KEY_CB_NAME, "onCapabilitiesChanged");
117             event.getData().putParcelable(EVENT_KEY_NETWORK, network);
118             event.getData().putParcelable(EVENT_KEY_NETWORK_CAP, networkCapabilities);
119             mNetWork = network;
120             mNetworkCapabilities = networkCapabilities;
121             TransportInfo transportInfo = networkCapabilities.getTransportInfo();
122             String transportInfoClassName = "";
123             if (transportInfo != null) {
124                 transportInfoClassName = transportInfo.getClass().getName();
125                 event.getData().putString(EVENT_KEY_TRANSPORT_INFO_CLASS, transportInfoClassName);
126             }
127             if (networkCapabilities.getTransportInfo() instanceof WifiAwareNetworkInfo) {
128                 WifiAwareNetworkInfo
129                         newWorkInfo =
130                         (WifiAwareNetworkInfo) networkCapabilities.getTransportInfo();
131                 List<WifiAwareChannelInfo> channelInfoList = newWorkInfo.getChannelInfoList();
132                 ArrayList<Integer> channelFrequencies = new ArrayList<>();
133                 if (!channelInfoList.isEmpty()) {
134                     for (WifiAwareChannelInfo info : channelInfoList) {
135                         channelFrequencies.add(info.getChannelFrequencyMhz());
136                     }
137                 }
138                 event.getData().putIntegerArrayList(
139                     EVENT_KEY_TRANSPORT_INFO_CHANNEL_IN_MHZ, channelFrequencies
140                 );
141 
142             }
143             EventCache.getInstance().postEvent(event);
144         }
145     }
146 
147     /**
148      * Requests a network with the specified network request and sets a callback for network
149      * events.
150      *
151      * @param callBackId              A unique identifier assigned automatically by Mobly. This is
152      *                                used as the request ID for further operations and event
153      *                                handling.
154      * @param request                 The NetworkRequest object that specifies the desired network
155      *                                characteristics.
156      * @param requestNetWorkId        A unique ID to support managing multiple network sessions.
157      * @param requestNetworkTimeoutMs The timeout period (in milliseconds) after which the network
158      *                                request will expire if no suitable network is found.
159      */
160     @AsyncRpc(description = "Request a network.")
connectivityRequestNetwork(String callBackId, String requestNetWorkId, NetworkRequest request, int requestNetworkTimeoutMs)161     public void connectivityRequestNetwork(String callBackId, String requestNetWorkId,
162             NetworkRequest request, int requestNetworkTimeoutMs) {
163         Log.v("Requesting network with request: " + request.toString());
164         NetworkCallback callback = new NetworkCallback(callBackId);
165         mNetworkCallBacks.put(requestNetWorkId, callback);
166         mConnectivityManager.requestNetwork(request, callback, requestNetworkTimeoutMs);
167     }
168 
169     /**
170      * Unregisters the registered network callback and possibly releases requested networks.
171      *
172      * @param requestId Id of the network request.
173      */
174     @Rpc(description = "Unregister a network request")
connectivityUnregisterNetwork(String requestId)175     public void connectivityUnregisterNetwork(String requestId) {
176         NetworkCallback callback = mNetworkCallBacks.get(requestId);
177         if (callback == null) {
178             return;
179         }
180         if (mConnectivityManager == null) {
181             return;
182         }
183         mConnectivityManager.unregisterNetworkCallback(callback);
184     }
185 
186     /**
187      * Starts a server socket on a random available port and waits for incoming connections. A
188      * separate thread is started to handle the socket accept operation asynchronously. The accepted
189      * socket is stored and used for further communication (read/write).
190      *
191      * @param callbackId A unique identifier assigned automatically by Mobly to track the event and
192      *                   response.
193      * @return The port number assigned by the local system.
194      */
195     @AsyncRpc(description = "Start a server socket to accept incoming connections.")
connectivityServerSocketAccept(String callbackId)196     public int connectivityServerSocketAccept(String callbackId)
197             throws ConnectivityManagerSnippetException, IOException {
198         if (mServerSockets.containsKey(callbackId) && mServerSockets.get(callbackId) != null) {
199             throw new ConnectivityManagerSnippetException("Server socket is already created.");
200         }
201         ServerSocket serverSocket = new ServerSocket(0);
202         int localPort = serverSocket.getLocalPort();
203         mServerSockets.put(callbackId, serverSocket);
204         // https://developer.callbackId.com/reference/java/net/ServerSocket#setSoTimeout(int)
205         // A call to accept() for this ServerSocket will block for only this amount of time.
206         serverSocket.setSoTimeout(ACCEPT_TIMEOUT);
207         if (mSocketThreads.get(callbackId) != null) {
208             throw new ConnectivityManagerSnippetException(
209                     "Server socket thread is already running.");
210         }
211         Thread socketThread = new Thread(() -> {
212             try {
213                 Socket tempSocket = mServerSockets.get(callbackId).accept();
214                 mSockets.put(callbackId, tempSocket);
215                 mInputStreams.put(callbackId, tempSocket.getInputStream());
216                 mOutputStreams.put(callbackId, tempSocket.getOutputStream());
217                 SnippetEvent event = new SnippetEvent(callbackId, "ServerSocketAccept");
218                 event.getData().putBoolean("isAccept", true);
219                 EventCache.getInstance().postEvent(event);
220             } catch (IOException e) {
221                 Log.e("Socket accept error", e);
222                 SnippetEvent event = new SnippetEvent(callbackId, "ServerSocketAccept");
223                 event.getData().putBoolean("isAccept", false);
224                 event.getData().putString("error", e.getMessage());
225                 EventCache.getInstance().postEvent(event);
226             }
227         });
228         mSocketThreads.put(callbackId, socketThread);
229         socketThread.start();
230         return localPort;
231     }
232 
233     /**
234      * Check if the server socket thread is alive.
235      *
236      * @param sessionId To support multiple network requests happening simultaneously
237      * @return True if the server socket thread is alive.
238      */
connectivityIsSocketThreadAlive(String sessionId)239     public boolean connectivityIsSocketThreadAlive(String sessionId) {
240         Thread thread = mSocketThreads.get(sessionId);
241         if (thread != null) {
242             return thread.isAlive();
243         } else {
244             return false;
245         }
246     }
247 
248     /**
249      * Stops the server socket thread if it's running.
250      *
251      * @param sessionId To support multiple network requests happening simultaneously
252      */
253     @Rpc(description = "Stop the server socket thread if it's running.")
connectivityStopAcceptThread(String sessionId)254     public void connectivityStopAcceptThread(String sessionId) throws IOException {
255         if (connectivityIsSocketThreadAlive(sessionId)) {
256             Thread thread = mSocketThreads.get(sessionId);
257 
258             try {
259                 connectivityCloseServerSocket(sessionId);
260                 thread.join(CLOSE_SOCKET_TIMEOUT);  // Wait for the thread to terminate
261                 if (thread.isAlive()) {
262                     throw new RuntimeException("Server socket thread did not terminate in time");
263                 }
264             } catch (InterruptedException e) {
265                 throw new RuntimeException("Error stopping server socket thread", e);
266             } finally {
267                 connectivityCloseSocket(sessionId);
268                 mSocketThreads.remove(sessionId);
269             }
270         } else {
271             connectivityCloseSocket(sessionId);
272             mSocketThreads.remove(sessionId);
273         }
274     }
275 
276     /**
277      * Reads from a socket.
278      *
279      * @param sessionId To support multiple network requests happening simultaneously
280      * @param len       The number of bytes to read.
281      */
282     @Rpc(description = "Reads from a socket.")
connectivityReadSocket(String sessionId, int len)283     public String connectivityReadSocket(String sessionId, int len)
284             throws ConnectivityManagerSnippetException, JSONException, IOException {
285         checkInputStream(sessionId);
286         // Read the specified number of bytes from the input stream
287         byte[] buffer = new byte[len];
288         InputStream inputStream = mInputStreams.get(sessionId);
289         int bytesReadLength = inputStream.read(buffer, 0, len); // Read up to len bytes
290         if (bytesReadLength == -1) { // End of stream reached unexpectedly
291             throw new ConnectivityManagerSnippetException(
292                     "End of stream reached before reading expected bytes.");
293         }
294         // Convert the bytes read to a String
295         String receiveStrMsg = new String(buffer, 0, bytesReadLength, StandardCharsets.UTF_8);
296         return receiveStrMsg;
297     }
298 
299     /**
300      * Writes to a socket.
301      *
302      * @param sessionId To support multiple network requests happening simultaneously
303      * @param message   The message to send.
304      * @throws ConnectivityManagerSnippetException
305      */
306     @Rpc(description = "Writes to a socket.")
connectivityWriteSocket(String sessionId, String message)307     public Boolean connectivityWriteSocket(String sessionId, String message)
308             throws ConnectivityManagerSnippetException, IOException {
309         checkOutputStream(sessionId);
310         byte[] bytes = message.getBytes(StandardCharsets.UTF_8);
311         // Write the message to the output stream
312         OutputStream outputStream = mOutputStreams.get(sessionId);
313         outputStream.write(bytes, 0, bytes.length);
314         outputStream.flush();
315         return true;
316 
317 
318     }
319 
320     /**
321      * Closes the socket.
322      *
323      * @param sessionId To support multiple network requests happening simultaneously
324      * @throws ConnectivityManagerSnippetException
325      */
connectivityCloseSocket(String sessionId)326     public void connectivityCloseSocket(String sessionId) throws IOException {
327         Socket socket = mSockets.get(sessionId);
328         if (socket != null && !socket.isClosed()) {
329             socket.close();
330         }
331         mSockets.remove(sessionId);
332 
333     }
334 
335     /**
336      * Closes the server socket.
337      *
338      * @param sessionId To support multiple network requests happening simultaneously
339      * @throws IOException
340      */
connectivityCloseServerSocket(String sessionId)341     public void connectivityCloseServerSocket(String sessionId) throws IOException {
342         ServerSocket serverSocket = mServerSockets.get(sessionId);
343         if (serverSocket != null && !serverSocket.isClosed()) {
344             serverSocket.close();
345         }
346         mServerSockets.remove(sessionId);
347     }
348 
349     /**
350      * Closes the outputStream.
351      *
352      * @throws ConnectivityManagerSnippetException
353      */
354     @Rpc(description = "Close the outputStream.")
connectivityCloseWrite(String sessionId)355     public void connectivityCloseWrite(String sessionId)
356             throws IOException, ConnectivityManagerSnippetException {
357         OutputStream outputStream = mOutputStreams.get(sessionId);
358         if (outputStream != null) {
359             outputStream.close();
360         }
361         mOutputStreams.remove(sessionId);
362 
363 
364     }
365 
366     /**
367      * Closes the inputStream.
368      *
369      * @throws ConnectivityManagerSnippetException
370      */
371     @Rpc(description = "Close the inputStream.")
connectivityCloseRead(String sessionId)372     public void connectivityCloseRead(String sessionId)
373             throws IOException, ConnectivityManagerSnippetException {
374         InputStream inputStream = mInputStreams.get(sessionId);
375         if (inputStream != null) {
376             inputStream.close();
377         }
378         mInputStreams.remove(sessionId);
379     }
380 
checkOutputStream(String sessionId)381     private void checkOutputStream(String sessionId) throws ConnectivityManagerSnippetException {
382         OutputStream outputStream = mOutputStreams.get(sessionId);
383         if (outputStream == null) {
384             throw new ConnectivityManagerSnippetException("Output stream is not created.Please "
385                     + "call connectivityCreateSocketOverWiFiAware() or "
386                     + "connectivityServerSocketAccept() first.");
387         }
388     }
389 
checkInputStream(String sessionId)390     private void checkInputStream(String sessionId) throws ConnectivityManagerSnippetException {
391         InputStream inputStream = mInputStreams.get(sessionId);
392         if (inputStream == null) {
393             throw new ConnectivityManagerSnippetException("Input stream is not created.Please "
394                     + "call connectivityCreateSocketOverWiFiAware() or "
395                     + "connectivityServerSocketAccept() first.");
396         }
397     }
398 
399     /**
400      * Creates a socket using Wi-Fi Aware's peer-to-peer connection capabilities. Only TCP transport
401      * protocol is supported. The method uses the session ID to track and manage the socket.
402      *
403      * @param sessionId     A unique ID to manage multiple network requests simultaneously.
404      * @param peerLocalPort The port number of the peer device.
405      */
406     @Rpc(description = "Create to a socket.")
connectivityCreateSocketOverWiFiAware(String sessionId, int peerLocalPort)407     public void connectivityCreateSocketOverWiFiAware(String sessionId, int peerLocalPort)
408             throws ConnectivityManagerSnippetException, IOException {
409         NetworkCallback netWorkCallBackBySessionId = getNetWorkCallbackBySessionId(sessionId);
410         NetworkCapabilities networkCapabilities = netWorkCallBackBySessionId.mNetworkCapabilities;
411         Network netWork = netWorkCallBackBySessionId.mNetWork;
412         checkNetworkCapabilities(networkCapabilities);
413         checkNetwork(netWork);
414         Socket socket = mSockets.get(sessionId);
415         if (socket != null) {
416             throw new ConnectivityManagerSnippetException("Socket is already created"
417                     + ".Please call connectivityCloseSocket(String sessionId) or "
418                     + "connectivityStopAcceptThread" + "(String sessionId) " + "to release first.");
419         }
420 
421         checkNetworkCapabilities(networkCapabilities);
422         WifiAwareNetworkInfo peerAwareInfo =
423                 (WifiAwareNetworkInfo) networkCapabilities.getTransportInfo();
424         if (peerAwareInfo == null) {
425             throw new ConnectivityManagerSnippetException("PeerAwareInfo is null.");
426         }
427         int peerPort = peerAwareInfo.getPort();
428         Inet6Address peerIpv6Addr = peerAwareInfo.getPeerIpv6Addr();
429         if (peerPort == 0) {
430             peerPort = peerLocalPort;
431             if (peerPort == 0) {
432                 throw new ConnectivityManagerSnippetException("Invalid port number.");
433             }
434         } else {
435 
436             int transportProtocol = peerAwareInfo.getTransportProtocol();
437             if (transportProtocol != TRANSPORT_PROTOCOL_TCP) {
438                 throw new ConnectivityManagerSnippetException(
439                         "Only support TCP transport protocol.");
440             }
441         }
442 
443 
444         Socket createSocket = netWork.getSocketFactory().createSocket(peerIpv6Addr, peerPort);
445         createSocket.setSoTimeout(SOCKET_SO_TIMEOUT);
446         mSockets.put(sessionId, createSocket);
447         mInputStreams.put(sessionId, createSocket.getInputStream());
448         mOutputStreams.put(sessionId, createSocket.getOutputStream());
449     }
450 
451 
getNetWorkCallbackBySessionId(String sessionId)452     private NetworkCallback getNetWorkCallbackBySessionId(String sessionId)
453             throws ConnectivityManagerSnippetException {
454         NetworkCallback callback = mNetworkCallBacks.get(sessionId);
455         if (callback == null) {
456             throw new ConnectivityManagerSnippetException("Network callback is not created.Please "
457                     + "call connectivityRequestNetwork() first.");
458 
459         }
460         return callback;
461     }
462 
463     /**
464      * Check if the network capabilities is created.
465      *
466      * @throws ConnectivityManagerSnippetException
467      */
checkNetworkCapabilities(NetworkCapabilities networkCapabilities)468     private void checkNetworkCapabilities(NetworkCapabilities networkCapabilities)
469             throws ConnectivityManagerSnippetException {
470         if (networkCapabilities == null) {
471             throw new ConnectivityManagerSnippetException("Network capabilities is not created.");
472         }
473     }
474 
475     /**
476      * Check if the network is created.
477      *
478      * @throws ConnectivityManagerSnippetException
479      */
checkNetwork(Network network)480     private void checkNetwork(Network network) throws ConnectivityManagerSnippetException {
481         if (network == null) {
482             throw new ConnectivityManagerSnippetException("Network is not created.");
483         }
484     }
485 
486     /**
487      * Check if the server socket is created.
488      *
489      * @throws ConnectivityManagerSnippetException
490      */
checkServerSocket(String sessionId)491     private void checkServerSocket(String sessionId) throws ConnectivityManagerSnippetException {
492         if (mServerSockets.get(sessionId) == null) {
493             throw new ConnectivityManagerSnippetException("Server socket is not created"
494                     + ".Please call connectivityInitServerSocket() first.");
495         }
496     }
497 
498     /**
499      * Close all sockets.
500      *
501      * @param sessionId To support multiple network requests happening simultaneously
502      * @throws IOException
503      */
504     @Rpc(description = "Close all sockets.")
connectivityCloseAllSocket(String sessionId)505     public void connectivityCloseAllSocket(String sessionId)
506             throws IOException, ConnectivityManagerSnippetException {
507         connectivityStopAcceptThread(sessionId);
508         connectivityCloseServerSocket(sessionId);
509         connectivityCloseRead(sessionId);
510         connectivityCloseWrite(sessionId);
511     }
512 
513     @Override
shutdown()514     public void shutdown() throws Exception {
515         try {
516             for (NetworkCallback callback : mNetworkCallBacks.values()) {
517                 mConnectivityManager.unregisterNetworkCallback(callback);
518             }
519             mNetworkCallBacks.clear();
520 
521         } catch (Exception e) {
522             Log.e("Error unregistering network callback", e);
523         }
524         try {
525             connectivityReleaseAllSockets();
526         } catch (Exception e) {
527             Log.e("Error closing sockets", e);
528         }
529         Snippet.super.shutdown();
530     }
531 
532     /**
533      * Close all sockets.
534      *
535      * @throws IOException
536      */
537     @Rpc(description = "Close all sockets.")
connectivityReleaseAllSockets()538     public void connectivityReleaseAllSockets() {
539         for (Socket socket : mSockets.values()) {
540             try {
541                 if (socket != null && !socket.isClosed()) {
542                     socket.close();
543                 }
544             } catch (IOException e) {
545                 Log.e("Error closing socket", e);
546             }
547         }
548         mSockets.clear();
549         for (ServerSocket serverSocket : mServerSockets.values()) {
550             try {
551                 if (serverSocket != null && !serverSocket.isClosed()) {
552                     serverSocket.close();
553                 }
554             } catch (IOException e) {
555                 Log.e("Error closing server socket", e);
556             }
557         }
558         mServerSockets.clear();
559         for (OutputStream outputStream : mOutputStreams.values()) {
560             try {
561                 if (outputStream != null) {
562                     outputStream.close();
563                 }
564             } catch (IOException e) {
565                 Log.e("Error closing output stream", e);
566             }
567         }
568         mOutputStreams.clear();
569         for (InputStream inputStream : mInputStreams.values()) {
570             try {
571                 if (inputStream != null) {
572                     inputStream.close();
573                 }
574             } catch (IOException e) {
575                 Log.e("Error closing input stream", e);
576             }
577         }
578         mInputStreams.clear();
579     }
580 }
581