# xref: /aosp_15_r20/tools/acloud/reconnect/reconnect_test.py (revision 800a58d989c669b8eb8a71d8df53b1ba3d411444)
# Copyright 2018 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""Tests for reconnect."""
15
16import collections
17import unittest
18import subprocess
19
20from unittest import mock
21
22from acloud import errors
23from acloud.internal import constants
24from acloud.internal.lib import auth
25from acloud.internal.lib import android_compute_client
26from acloud.internal.lib import cvd_runtime_config
27from acloud.internal.lib import driver_test_lib
28from acloud.internal.lib import gcompute_client
29from acloud.internal.lib import utils
30from acloud.internal.lib import ssh as ssh_object
31from acloud.internal.lib.adb_tools import AdbTools
32from acloud.list import list as list_instance
33from acloud.public import config
34from acloud.reconnect import reconnect
35
36
# Stand-in for the forwarded-ports result of utils.AutoConnect: the tests
# below use it as the return value of the patched AutoConnect. Field names
# are taken from constants so they stay in sync with the production keys.
_FORWARDED_PORT_FIELDS = (constants.VNC_PORT, constants.ADB_PORT)
ForwardedPorts = collections.namedtuple("ForwardedPorts", _FORWARDED_PORT_FIELDS)
39
40
class ReconnectTest(driver_test_lib.BaseDriverTest):
    """Test reconnect functions."""

    # pylint: disable=no-member, too-many-statements
    def testReconnectInstance(self):
        """Test ReconnectInstance with VNC autoconnect.

        Covers a remote instance whose ssh tunnel is already up, a remote
        instance whose tunnel is down (a new tunnel is created through
        utils.AutoConnect), a tunnel that fails to forward any port, and a
        local instance.
        """
        ssh_private_key_path = "/fake/acloud_rsa"
        fake_report = mock.MagicMock()
        instance_object = mock.MagicMock()
        instance_object.name = "fake_name"
        instance_object.ip = "1.1.1.1"
        instance_object.islocal = False
        instance_object.adb_port = "8686"
        instance_object.avd_type = "cuttlefish"
        self.Patch(subprocess, "check_call", return_value=True)
        self.Patch(utils, "LaunchVncClient")
        self.Patch(utils, "AutoConnect")
        self.Patch(AdbTools, "IsAdbConnected", return_value=False)
        self.Patch(AdbTools, "IsAdbConnectionAlive", return_value=False)
        self.Patch(utils, "IsCommandRunning", return_value=False)
        fake_device_dict = {
            constants.IP: "1.1.1.1",
            constants.INSTANCE_NAME: "fake_name",
            constants.VNC_PORT: 6666,
            constants.ADB_PORT: "8686",
            constants.DEVICE_SERIAL: "127.0.0.1:8686"
        }

        # ssh tunnel already connected, remote instance: no new tunnel is
        # created and the instance's existing ports are reported.
        instance_object.ssh_tunnel_is_connected = True
        instance_object.vnc_port = 6666
        instance_object.display = ""
        utils.AutoConnect.reset_mock()
        reconnect.ReconnectInstance(
            ssh_private_key_path, instance_object, fake_report, autoconnect="vnc")
        utils.AutoConnect.assert_not_called()
        utils.LaunchVncClient.assert_called_with(6666)
        fake_report.AddData.assert_called_with(key="devices", value=fake_device_dict)

        # Same, but with a display: its width/height are forwarded to the
        # VNC client.
        instance_object.display = "888x777 (99)"
        utils.AutoConnect.reset_mock()
        reconnect.ReconnectInstance(
            ssh_private_key_path, instance_object, fake_report, autoconnect="vnc")
        utils.AutoConnect.assert_not_called()
        utils.LaunchVncClient.assert_called_with(6666, "888", "777")
        fake_report.AddData.assert_called_with(key="devices", value=fake_device_dict)

        # ssh tunnel not connected, remote instance: a new tunnel is created
        # through utils.AutoConnect and the forwarded ports are reported.
        instance_object.ssh_tunnel_is_connected = False
        instance_object.display = ""
        instance_object.vnc_port = 5555
        extra_args_ssh_tunnel = None
        self.Patch(utils, "AutoConnect",
                   return_value=ForwardedPorts(vnc_port=11111, adb_port=22222))
        reconnect.ReconnectInstance(
            ssh_private_key_path, instance_object, fake_report, autoconnect="vnc")
        utils.AutoConnect.assert_called_with(ip_addr=instance_object.ip,
                                             rsa_key_file=ssh_private_key_path,
                                             target_vnc_port=constants.CF_VNC_PORT,
                                             target_adb_port=constants.CF_ADB_PORT,
                                             ssh_user=constants.GCE_USER,
                                             extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        utils.LaunchVncClient.assert_called_with(11111)
        fake_device_dict = {
            constants.IP: "1.1.1.1",
            constants.INSTANCE_NAME: "fake_name",
            constants.VNC_PORT: 11111,
            constants.ADB_PORT: 22222,
            constants.DEVICE_SERIAL: "127.0.0.1:22222"
        }
        fake_report.AddData.assert_called_with(key="devices", value=fake_device_dict)

        # Same, with a display and extra ssh tunnel arguments passed through.
        instance_object.display = "999x777 (99)"
        extra_args_ssh_tunnel = "fake_extra_args_ssh_tunnel"
        utils.AutoConnect.reset_mock()
        reconnect.ReconnectInstance(
            ssh_private_key_path, instance_object, fake_report,
            extra_args_ssh_tunnel=extra_args_ssh_tunnel,
            autoconnect="vnc")
        utils.AutoConnect.assert_called_with(ip_addr=instance_object.ip,
                                             rsa_key_file=ssh_private_key_path,
                                             target_vnc_port=constants.CF_VNC_PORT,
                                             target_adb_port=constants.CF_ADB_PORT,
                                             ssh_user=constants.GCE_USER,
                                             extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        utils.LaunchVncClient.assert_called_with(11111, "999", "777")
        fake_report.AddData.assert_called_with(key="devices", value=fake_device_dict)

        # Tunnel forwarded no ports: the device is reported as failing.
        self.Patch(utils, "AutoConnect",
                   return_value=ForwardedPorts(vnc_port=None, adb_port=None))
        reconnect.ReconnectInstance(
            ssh_private_key_path, instance_object, fake_report, autoconnect="vnc")
        fake_device_dict = {
            constants.IP: "1.1.1.1",
            constants.INSTANCE_NAME: "fake_name",
            constants.VNC_PORT: None,
            constants.ADB_PORT: None
        }
        fake_report.AddData.assert_called_with(key="device_failing_reconnect",
                                               value=fake_device_dict)

        # Local instance: no tunnel is created even though the tunnel flag
        # is False.
        instance_object.islocal = True
        instance_object.display = ""
        instance_object.vnc_port = 5555
        instance_object.ssh_tunnel_is_connected = False
        utils.AutoConnect.reset_mock()
        reconnect.ReconnectInstance(
            ssh_private_key_path, instance_object, fake_report, autoconnect="vnc")
        utils.AutoConnect.assert_not_called()
        utils.LaunchVncClient.assert_called_with(5555)
        fake_device_dict = {
            constants.IP: "1.1.1.1",
            constants.INSTANCE_NAME: "fake_name",
            constants.VNC_PORT: 5555,
            constants.ADB_PORT: "8686"
        }
        fake_report.AddData.assert_called_with(key="devices", value=fake_device_dict)

    # pylint: disable=no-member
    def testReconnectInstanceWithWebRTC(self):
        """Test reconnect instances with WebRTC."""
        ssh_private_key_path = "/fake/acloud_rsa"
        fake_report = mock.MagicMock()
        instance_object = mock.MagicMock()
        instance_object.ip = "1.1.1.1"
        instance_object.islocal = False
        instance_object.adb_port = "8686"
        instance_object.avd_type = "cuttlefish"
        self.Patch(subprocess, "check_call", return_value=True)
        self.Patch(utils, "LaunchVncClient")
        self.Patch(utils, "AutoConnect")
        self.Patch(utils, "LaunchBrowser")
        self.Patch(utils, "GetWebrtcPortFromSSHTunnel", return_value=None)
        self.Patch(utils, "EstablishWebRTCSshTunnel")
        self.Patch(utils, "PickFreePort", return_value=12345)
        self.Patch(AdbTools, "IsAdbConnected", return_value=False)
        self.Patch(AdbTools, "IsAdbConnectionAlive", return_value=False)
        self.Patch(utils, "IsCommandRunning", return_value=False)

        # Remote instance with no existing WebRTC tunnel: a free local port
        # is picked, a WebRTC ssh tunnel is established and the browser is
        # opened on it; no VNC/adb autoconnect happens.
        instance_object.vnc_port = 6666
        instance_object.display = ""
        utils.AutoConnect.reset_mock()
        reconnect.ReconnectInstance(ssh_private_key_path, instance_object, fake_report,
                                    extra_args_ssh_tunnel=None, autoconnect="webrtc")
        utils.AutoConnect.assert_not_called()
        utils.LaunchVncClient.assert_not_called()
        utils.EstablishWebRTCSshTunnel.assert_called_with(
            extra_args_ssh_tunnel=None,
            webrtc_local_port=12345,
            ip_addr=instance_object.ip,
            rsa_key_file=ssh_private_key_path,
            ssh_user=constants.GCE_USER)
        utils.LaunchBrowser.assert_called_with('localhost', 12345)
        utils.PickFreePort.assert_called_once()
        utils.PickFreePort.reset_mock()

        # An existing WebRTC tunnel port is found: no new local port is picked.
        self.Patch(utils, "GetWebrtcPortFromSSHTunnel", return_value="11111")
        reconnect.ReconnectInstance(ssh_private_key_path, instance_object, fake_report,
                                    extra_args_ssh_tunnel=None, autoconnect="webrtc")
        utils.PickFreePort.assert_not_called()

        # Local WebRTC instance: no local port is picked either.
        instance_object.islocal = True
        reconnect.ReconnectInstance(ssh_private_key_path, instance_object, fake_report,
                                    extra_args_ssh_tunnel=None, autoconnect="webrtc")
        utils.PickFreePort.assert_not_called()

        # Autoconnect "adb" only should launch neither a browser nor VNC.
        utils.LaunchBrowser.reset_mock()
        utils.LaunchVncClient.reset_mock()
        reconnect.ReconnectInstance(ssh_private_key_path, instance_object, fake_report,
                                    extra_args_ssh_tunnel=None, autoconnect="adb")
        utils.LaunchBrowser.assert_not_called()
        utils.LaunchVncClient.assert_not_called()

    def testReconnectInstanceAvdtype(self):
        """Test that the tunnel target ports follow the instance's avd_type."""
        ssh_private_key_path = "/fake/acloud_rsa"
        fake_report = mock.MagicMock()
        instance_object = mock.MagicMock()
        instance_object.ip = "1.1.1.1"
        instance_object.vnc_port = 9999
        instance_object.adb_port = "9999"
        instance_object.islocal = False
        instance_object.ssh_tunnel_is_connected = False
        self.Patch(utils, "AutoConnect")
        self.Patch(reconnect, "StartVnc")

        # Remote instance with avd_type "gce" uses the GCE target ports.
        instance_object.avd_type = "gce"
        reconnect.ReconnectInstance(
            ssh_private_key_path, instance_object, fake_report, autoconnect="vnc")
        utils.AutoConnect.assert_called_with(ip_addr=instance_object.ip,
                                             rsa_key_file=ssh_private_key_path,
                                             target_vnc_port=constants.GCE_VNC_PORT,
                                             target_adb_port=constants.GCE_ADB_PORT,
                                             ssh_user=constants.GCE_USER,
                                             extra_args_ssh_tunnel=None)
        reconnect.StartVnc.assert_called_once()

        # Remote instance with avd_type "cuttlefish" uses the CF target ports.
        instance_object.avd_type = "cuttlefish"
        reconnect.StartVnc.reset_mock()
        reconnect.ReconnectInstance(
            ssh_private_key_path, instance_object, fake_report, autoconnect="vnc")
        utils.AutoConnect.assert_called_with(ip_addr=instance_object.ip,
                                             rsa_key_file=ssh_private_key_path,
                                             target_vnc_port=constants.CF_VNC_PORT,
                                             target_adb_port=constants.CF_ADB_PORT,
                                             ssh_user=constants.GCE_USER,
                                             extra_args_ssh_tunnel=None)
        reconnect.StartVnc.assert_called_once()

    def testReconnectInstanceUnknownAvdType(self):
        """Test reconnect instances of unknown avd type."""
        ssh_private_key_path = "/fake/acloud_rsa"
        fake_report = mock.MagicMock()
        instance_object = mock.MagicMock()
        instance_object.avd_type = "unknown"
        self.assertRaises(errors.UnknownAvdType,
                          reconnect.ReconnectInstance,
                          ssh_private_key_path,
                          instance_object,
                          fake_report)

    def testReconnectInstanceNoAvdType(self):
        """Test reconnect instances with no avd type."""
        ssh_private_key_path = "/fake/acloud_rsa"
        fake_report = mock.MagicMock()
        # avd_type is left as an auto-created MagicMock attribute.
        instance_object = mock.MagicMock()
        self.assertRaises(errors.UnknownAvdType,
                          reconnect.ReconnectInstance,
                          ssh_private_key_path,
                          instance_object,
                          fake_report)

    def testStartVnc(self):
        """Test start Vnc."""
        self.Patch(subprocess, "check_call", return_value=True)
        self.Patch(utils, "IsCommandRunning", return_value=False)
        self.Patch(utils, "LaunchVncClient")
        vnc_port = 5555

        # No display: only the port is passed to the VNC client.
        display = ""
        reconnect.StartVnc(vnc_port, display)
        utils.LaunchVncClient.assert_called_with(5555)

        # With a display: its width and height are passed as well.
        display = "888x777 (99)"
        reconnect.StartVnc(vnc_port, display)
        utils.LaunchVncClient.assert_called_with(5555, "888", "777")
        utils.LaunchVncClient.reset_mock()

        # A VNC client is already running: no new client is launched.
        self.Patch(utils, "IsCommandRunning", return_value=True)
        reconnect.StartVnc(vnc_port, display)
        utils.LaunchVncClient.assert_not_called()

    # pylint: disable=protected-access
    def testIsWebrtcEnable(self):
        """Test _IsWebrtcEnable."""
        # Local instance: the answer comes from its runtime config.
        fake_ins = mock.MagicMock()
        fake_ins.islocal = True
        fake_ins.cf_runtime_cfg = mock.MagicMock()
        fake_ins.cf_runtime_cfg.enable_webrtc = False
        self.assertFalse(reconnect._IsWebrtcEnable(fake_ins, "fake_user", "ssh_pkey_path", ""))

        # Remote instance: the runtime config is fetched over ssh.
        fake_ins.islocal = False
        fake_runtime_config = mock.MagicMock()
        fake_runtime_config.enable_webrtc = True
        self.Patch(ssh_object, "Ssh")
        self.Patch(ssh_object.Ssh, "GetCmdOutput", return_value="fake_rawdata")
        self.Patch(cvd_runtime_config, "CvdRuntimeConfig",
                   return_value=fake_runtime_config)
        self.assertTrue(reconnect._IsWebrtcEnable(fake_ins, "fake_user", "ssh_pkey_path", ""))

        # Unparsable remote runtime config is treated as WebRTC disabled.
        self.Patch(cvd_runtime_config, "CvdRuntimeConfig",
                   side_effect=errors.ConfigError)
        self.assertFalse(reconnect._IsWebrtcEnable(fake_ins, "fake_user", "ssh_pkey_path", ""))

    def testRun(self):
        """Test Run."""
        fake_args = mock.MagicMock()
        fake_args.autoconnect = "webrtc"
        fake_args.instance_names = ["fake-ins-name"]
        fake_ins1 = mock.MagicMock()
        fake_ins1.avd_type = "cuttlefish"
        fake_ins1.islocal = False
        fake_ins2 = mock.MagicMock()
        fake_ins2.avd_type = "cuttlefish"
        fake_ins2.islocal = False
        fake_ins_gf = mock.MagicMock()
        fake_ins_gf.avd_type = "goldfish"
        fake_ins_gf.islocal = False
        fake_ins_gf.vnc_port = 1234
        ins_to_reconnect = [fake_ins1]
        # mock args.all equal to True and return 3 instances.
        all_ins_to_reconnect = [fake_ins1, fake_ins2, fake_ins_gf]
        cfg = mock.MagicMock()
        cfg.ssh_private_key_path = None
        cfg.extra_args_ssh_tunnel = None
        self.Patch(config, "GetAcloudConfig", return_value=cfg)
        self.Patch(list_instance, "GetInstancesFromInstanceNames",
                   return_value=ins_to_reconnect)
        self.Patch(list_instance, "ChooseInstances",
                   return_value=all_ins_to_reconnect)
        self.Patch(auth, "CreateCredentials")
        self.Patch(android_compute_client, "AndroidComputeClient")
        self.Patch(android_compute_client.AndroidComputeClient,
                   "AddSshRsaInstanceMetadata")
        self.Patch(reconnect, "ReconnectInstance")

        # Named instances are looked up by name and each one is reconnected.
        reconnect.Run(fake_args)
        list_instance.GetInstancesFromInstanceNames.assert_called_once()
        self.assertEqual(reconnect.ReconnectInstance.call_count, 1)
        reconnect.ReconnectInstance.reset_mock()

        # should reconnect all instances
        fake_args.instance_names = None
        reconnect.Run(fake_args)
        list_instance.ChooseInstances.assert_called_once()
        self.assertEqual(reconnect.ReconnectInstance.call_count, 3)
        reconnect.ReconnectInstance.reset_mock()

        # Only one of a local cuttlefish instance and a remote instance of
        # unknown avd type is reconnected.
        fake_ins1.islocal = True
        fake_ins2.avd_type = "unknown"
        self.Patch(list_instance, "ChooseInstances",
                   return_value=[fake_ins1, fake_ins2])
        reconnect.Run(fake_args)
        self.assertEqual(reconnect.ReconnectInstance.call_count, 1)

    def testGetSshConnectHostname(self):
        """Test GetSshConnectHostname."""
        self.Patch(gcompute_client, "GetGCEHostName", return_value="fake_host")
        instance = mock.MagicMock()
        # Local instance has no ssh hostname.
        instance.islocal = True
        cfg = mock.MagicMock()
        self.assertEqual(None, reconnect.GetSshConnectHostname(cfg, instance))

        # Remote instance will get the GCE hostname.
        instance.islocal = False
        cfg.connect_hostname = True
        self.assertEqual("fake_host",
                         reconnect.GetSshConnectHostname(cfg, instance))
385                         reconnect.GetSshConnectHostname(cfg, instance))
386
387
if __name__ == "__main__":
    # Executed directly (not imported by a test runner): run every test
    # case defined in this module.
    unittest.main()
390