1# Copyright 2021 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging
15from typing import List
16
17from absl import flags
18from absl.testing import absltest
19
20from framework import xds_k8s_flags
21from framework import xds_k8s_testcase
22from framework.helpers import skips
23from framework.infrastructure import k8s
24from framework.test_app.runners.k8s import k8s_xds_server_runner
25
# Make the key flags of xds_k8s_testcase key flags of this module as well.
flags.adopt_module_key_flags(xds_k8s_testcase)
logger = logging.getLogger(__name__)

# Short module-private aliases for the framework types used below.
_KubernetesServerRunner = k8s_xds_server_runner.KubernetesServerRunner
_XdsTestClient = xds_k8s_testcase.XdsTestClient
_XdsTestServer = xds_k8s_testcase.XdsTestServer
_Lang = skips.Lang
34
35
class FailoverTest(xds_k8s_testcase.RegularXdsKubernetesTestCase):
    """Checks how RPCs fail over from the primary to the secondary locality.

    The test starts REPLICA_COUNT test servers with the default server runner
    and additional servers with a second runner that is backed by
    ``secondary_k8s_api_manager``. It then switches the default servers to
    "not serving" one at a time and asserts, after every step, which servers
    receive the RPCs sent by the test client.
    """

    # Number of replicas started by the default server runner. The test steps
    # address servers 0, 1 and 2 by index, so at least three are required.
    REPLICA_COUNT = 3
    # Passed as max_rate_per_endpoint when the backends of both runners are
    # added to the backend service.
    MAX_RATE_PER_ENDPOINT = 100

    @classmethod
    def setUpClass(cls):
        """Runs the base class setup and pins the server image if needed."""
        super().setUpClass()
        # Force the python client to use the reference server image (Java)
        # because the python server doesn't yet support set_not_serving RPC.
        # TODO(https://github.com/grpc/grpc/issues/30635): Remove when resolved.
        if cls.lang_spec.client_lang == _Lang.PYTHON:
            cls.server_image = xds_k8s_flags.SERVER_IMAGE_CANONICAL.value

    def setUp(self):
        """Creates the server runner for the secondary cluster.

        Only the runner object is created here; its servers are started by
        the test itself.
        """
        super().setUp()
        self.secondary_server_runner = _KubernetesServerRunner(
            k8s.KubernetesNamespace(self.secondary_k8s_api_manager,
                                    self.server_namespace),
            deployment_name=self.server_name + '-alt',
            image_name=self.server_image,
            gcp_service_account=self.gcp_service_account,
            td_bootstrap_image=self.td_bootstrap_image,
            gcp_project=self.project,
            gcp_api_manager=self.gcp_api_manager,
            xds_server_uri=self.xds_server_uri,
            network=self.network,
            debug_use_port_forwarding=self.debug_use_port_forwarding,
            # This runner's namespace is created in the secondary cluster,
            # so it's not reused and must be cleaned up.
            reuse_namespace=False)

    def cleanup(self):
        """Cleans up the base test resources and the secondary server runner.

        The secondary runner is cleaned up in a ``finally`` clause, so an
        exception raised by the base class cleanup doesn't cause the
        resources of the secondary runner to be skipped. The original
        exception still propagates to the caller.
        """
        try:
            super().cleanup()
        finally:
            # The attribute is missing when setUp didn't get as far as
            # creating the runner.
            if hasattr(self, 'secondary_server_runner'):
                self.secondary_server_runner.cleanup(
                    force=self.force_cleanup,
                    force_namespace=self.force_cleanup)

    def test_failover(self) -> None:
        """Takes the primary servers down one by one and checks RPC routing.

        Steps 00-08 create the traffic director resources, start the servers
        of both runners, register them as backends and start the client.
        Steps 09-13 change the serving status of the default servers and
        assert which servers the client's RPCs eventually reach.
        """
        with self.subTest('00_create_health_check'):
            self.td.create_health_check()

        with self.subTest('01_create_backend_services'):
            self.td.create_backend_service()

        with self.subTest('02_create_url_map'):
            self.td.create_url_map(self.server_xds_host, self.server_xds_port)

        with self.subTest('03_create_target_proxy'):
            self.td.create_target_proxy()

        with self.subTest('04_create_forwarding_rule'):
            self.td.create_forwarding_rule(self.server_xds_port)

        # Declared up front because the values are assigned inside a subTest
        # block and used by the later steps.
        default_test_servers: List[_XdsTestServer]
        alternate_test_servers: List[_XdsTestServer]
        with self.subTest('05_start_test_servers'):
            default_test_servers = self.startTestServers(
                replica_count=self.REPLICA_COUNT)

            alternate_test_servers = self.startTestServers(
                server_runner=self.secondary_server_runner)

        with self.subTest('06_add_server_backends_to_backend_services'):
            self.setupServerBackends(
                max_rate_per_endpoint=self.MAX_RATE_PER_ENDPOINT)
            self.setupServerBackends(
                server_runner=self.secondary_server_runner,
                max_rate_per_endpoint=self.MAX_RATE_PER_ENDPOINT)

        test_client: _XdsTestClient
        with self.subTest('07_start_test_client'):
            test_client = self.startTestClient(default_test_servers[0])

        with self.subTest('08_test_client_xds_config_exists'):
            self.assertXdsConfigExists(test_client)

        # All default servers are serving: only they should get the RPCs.
        with self.subTest('09_primary_locality_receives_requests'):
            self.assertRpcsEventuallyGoToGivenServers(test_client,
                                                      default_test_servers)

        # One default server down: the remaining default servers take over,
        # the alternate servers are still not expected to get RPCs.
        with self.subTest(
                '10_secondary_locality_receives_no_requests_on_partial_primary_failure'
        ):
            default_test_servers[0].set_not_serving()
            self.assertRpcsEventuallyGoToGivenServers(test_client,
                                                      default_test_servers[1:])

        # Two default servers down: RPCs are expected on the last serving
        # default server and on the alternate servers.
        with self.subTest('11_gentle_failover'):
            default_test_servers[1].set_not_serving()
            self.assertRpcsEventuallyGoToGivenServers(
                test_client, default_test_servers[2:] + alternate_test_servers)

        # All default servers down: only the alternate servers get the RPCs.
        with self.subTest(
                '12_secondary_locality_receives_requests_on_primary_failure'):
            default_test_servers[2].set_not_serving()
            self.assertRpcsEventuallyGoToGivenServers(test_client,
                                                      alternate_test_servers)

        # Restore exactly the servers the final assertion expects traffic on.
        with self.subTest('13_traffic_resumes_to_healthy_backends'):
            for server in default_test_servers:
                server.set_serving()
            self.assertRpcsEventuallyGoToGivenServers(test_client,
                                                      default_test_servers)
139
140
if __name__ == '__main__':
    # Script entry point: hand control to the absl test runner. failfast is
    # forwarded to the runner so the run stops on the first failure.
    absltest.main(failfast=True)
143