1# Copyright 2021 The gRPC Authors
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""A test framework built for urlMap related xDS test cases."""
15
16import abc
17from dataclasses import dataclass
18import datetime
19import json
20import os
21import re
22import sys
23import time
24from typing import Any, Iterable, Mapping, Optional, Tuple
25import unittest
26
27from absl import flags
28from absl import logging
29from absl.testing import absltest
30from google.protobuf import json_format
31import grpc
32
33from framework import xds_k8s_testcase
34from framework import xds_url_map_test_resources
35from framework.helpers import retryers
36from framework.helpers import skips
37from framework.infrastructure import k8s
38from framework.test_app import client_app
39from framework.test_app.runners.k8s import k8s_xds_client_runner
40
# Load existing flags
flags.adopt_module_key_flags(xds_k8s_testcase)
flags.adopt_module_key_flags(xds_url_map_test_resources)

# Define urlMap specific flags
QPS = flags.DEFINE_integer('qps', default=25, help='The QPS client is sending')

# Test configs
# Timeout, in seconds, of the retry loop in test_client_config that waits for
# the client to receive an xDS config accepted by xds_config_validate.
_URL_MAP_PROPAGATE_TIMEOUT_SEC = 600
# With the per-run IAM change, the first xDS response is delayed by several
# minutes, so the config is polled at a longer interval to reduce log spam.
_URL_MAP_PROPAGATE_CHECK_INTERVAL_SEC = 15
# File name suffix stripped from a test module's file name when deriving the
# short module name of its test case classes.
URL_MAP_TESTCASE_FILE_SUFFIX = '_test.py'
# Seconds to sleep after reconfiguring the test client, before querying its
# load balancer stats (see configure_and_send).
_CLIENT_CONFIGURE_WAIT_SEC = 2

# Type aliases
XdsTestClient = client_app.XdsTestClient
GcpResourceManager = xds_url_map_test_resources.GcpResourceManager
HostRule = xds_url_map_test_resources.HostRule
PathMatcher = xds_url_map_test_resources.PathMatcher
_KubernetesClientRunner = k8s_xds_client_runner.KubernetesClientRunner
# JSON-like value, e.g. the output of json_format.MessageToDict.
JsonType = Any
_timedelta = datetime.timedelta

# ProtoBuf translatable RpcType enums
RpcTypeUnaryCall = 'UNARY_CALL'
RpcTypeEmptyCall = 'EMPTY_CALL'
68
69
70def _split_camel(s: str, delimiter: str = '-') -> str:
71    """Turn camel case name to snake-case-like name."""
72    return ''.join(delimiter + c.lower() if c.isupper() else c
73                   for c in s).lstrip(delimiter)
74
75
class DumpedXdsConfig(dict):
    """A dict holding a dumped xDS config, plus pre-computed shortcuts.

    The instance is the JSON dict itself. On top of that, the constructor
    extracts commonly checked pieces into attributes:

    - lds: the listener config, or None if not found
    - rds / rds_version: the route config and its version, or None
    - cds: list of cluster configs
    - eds: list of endpoint (ClusterLoadAssignment) configs
    - endpoints: 'address:port' strings of all endpoints marked HEALTHY

    Both the per-type 'xdsConfig' layout and the 'genericXdsConfigs' layout
    are understood. Parsing is best-effort: an entry that cannot be parsed is
    logged at debug level and skipped.

    Feel free to add more pre-compute fields.
    """

    def __init__(self, xds_json: JsonType):
        super().__init__(xds_json)
        self.json_config = xds_json
        self.lds = None
        self.rds = None
        self.rds_version = None
        self.cds = []
        self.eds = []
        self.endpoints = []
        for per_type_entry in self.get('xdsConfig', []):
            self._ingest_per_type_config(per_type_entry)
        for generic_entry in self.get('genericXdsConfigs', []):
            self._ingest_generic_config(generic_entry)
        self._collect_healthy_endpoints()

    def _ingest_per_type_config(self, entry: JsonType) -> None:
        """Extracts LDS/RDS/CDS/EDS data from one 'xdsConfig' entry."""
        try:
            if 'listenerConfig' in entry:
                listeners = entry['listenerConfig']['dynamicListeners']
                self.lds = listeners[0]['activeState']['listener']
            elif 'routeConfig' in entry:
                first_route = entry['routeConfig']['dynamicRouteConfigs'][0]
                # rds is stored first, so it is kept even when the version
                # lookup below fails.
                self.rds = first_route['routeConfig']
                self.rds_version = first_route['versionInfo']
            elif 'clusterConfig' in entry:
                clusters = entry['clusterConfig']['dynamicActiveClusters']
                for cluster in clusters:
                    self.cds.append(cluster['cluster'])
            elif 'endpointConfig' in entry:
                assignments = entry['endpointConfig'][
                    'dynamicEndpointConfigs']
                for assignment in assignments:
                    self.eds.append(assignment['endpointConfig'])
        # TODO(lidiz) reduce the catch to LookupError
        except Exception as e:  # pylint: disable=broad-except
            logging.debug('Parsing dumped xDS config failed with %s: %s',
                          type(e), e)

    def _ingest_generic_config(self, entry: JsonType) -> None:
        """Extracts LDS/RDS/CDS/EDS data from one 'genericXdsConfigs' entry."""
        try:
            type_url = entry['typeUrl']
            if re.search(r'\.Listener$', type_url):
                self.lds = entry["xdsConfig"]
            elif re.search(r'\.RouteConfiguration$', type_url):
                self.rds = entry["xdsConfig"]
                self.rds_version = entry["versionInfo"]
            elif re.search(r'\.Cluster$', type_url):
                self.cds.append(entry["xdsConfig"])
            elif re.search(r'\.ClusterLoadAssignment$', type_url):
                self.eds.append(entry["xdsConfig"])
        # TODO(lidiz) reduce the catch to LookupError
        except Exception as e:  # pylint: disable=broad-except
            logging.debug('Parsing dumped xDS config failed with %s: %s',
                          type(e), e)

    def _collect_healthy_endpoints(self) -> None:
        """Fills self.endpoints with 'address:port' of HEALTHY endpoints."""
        for assignment in self.eds:
            for locality in assignment.get('endpoints', {}):
                for lb_endpoint in locality.get('lbEndpoints', {}):
                    try:
                        if lb_endpoint['healthStatus'] != 'HEALTHY':
                            continue
                        socket_address = lb_endpoint['endpoint']['address'][
                            'socketAddress']
                        self.endpoints.append(
                            '%s:%s' % (socket_address['address'],
                                       socket_address['portValue']))
                    # TODO(lidiz) reduce the catch to LookupError
                    except Exception as e:  # pylint: disable=broad-except
                        logging.debug('Parse endpoint failed with %s: %s',
                                      type(e), e)

    def __str__(self) -> str:
        """Returns the config as JSON text indented by 2 spaces."""
        return json.dumps(self, indent=2)
147
148
class RpcDistributionStats:
    """A convenience class to check RPC distribution.

    Built from the JSON dict form of a load balancer stats response. A peer
    whose name contains 'alternative' is counted towards the alternative
    service; every other peer is counted towards the default service. The
    'UnaryCall' method feeds the unary_call_* counters; any other method
    feeds the empty_call_* counters.

    All counters are summed over every peer, so they stay consistent with
    the per-service totals when a service has more than one peer.

    Feel free to add more pre-compute fields.
    """
    num_failures: int
    num_peers: int
    num_oks: int
    default_service_rpc_count: int
    alternative_service_rpc_count: int
    unary_call_default_service_rpc_count: int
    empty_call_default_service_rpc_count: int
    unary_call_alternative_service_rpc_count: int
    empty_call_alternative_service_rpc_count: int
    raw: JsonType

    def __init__(self, json_lb_stats: JsonType):
        """Computes the counters from the given stats dict.

        Args:
            json_lb_stats: A dict that may contain 'numFailures',
              'rpcsByPeer' (peer -> count) and 'rpcsByMethod'
              (method -> {'rpcsByPeer': {peer -> count}}). Missing keys are
              treated as empty.
        """
        self.num_failures = json_lb_stats.get('numFailures', 0)
        self.num_peers = len(json_lb_stats.get('rpcsByPeer', {}))
        self.num_oks = 0
        self.default_service_rpc_count = 0
        self.alternative_service_rpc_count = 0
        self.unary_call_default_service_rpc_count = 0
        self.empty_call_default_service_rpc_count = 0
        self.unary_call_alternative_service_rpc_count = 0
        self.empty_call_alternative_service_rpc_count = 0
        self.raw = json_lb_stats

        rpcs_by_method = json_lb_stats.get('rpcsByMethod', {})
        for rpc_type, method_stats in rpcs_by_method.items():
            is_unary = rpc_type == 'UnaryCall'
            # A method entry without any peer carries no 'rpcsByPeer' key.
            for peer, count in method_stats.get('rpcsByPeer', {}).items():
                self.num_oks += count
                # Accumulate (not overwrite) the per-method counters, so that
                # multiple peers of the same service are all accounted for.
                if 'alternative' in peer:
                    self.alternative_service_rpc_count += count
                    if is_unary:
                        self.unary_call_alternative_service_rpc_count += count
                    else:
                        self.empty_call_alternative_service_rpc_count += count
                else:
                    self.default_service_rpc_count += count
                    if is_unary:
                        self.unary_call_default_service_rpc_count += count
                    else:
                        self.empty_call_default_service_rpc_count += count
199
200
@dataclass
class ExpectedResult:
    """Describes the expected result of assertRpcStatusCode method below."""
    # RPC type whose per-method stats are inspected, e.g. RpcTypeUnaryCall.
    rpc_type: str = RpcTypeUnaryCall
    # Status code that RPCs of rpc_type are expected to finish with.
    status_code: grpc.StatusCode = grpc.StatusCode.OK
    # Expected fraction of the rpc_type RPCs that finish with status_code.
    ratio: float = 1
207
208
class _MetaXdsUrlMapTestCase(type):
    """Metaclass that tracks the urlMap test case classes it creates.

    Every class created through this metaclass receives references to the
    shared tracking containers below, and a 'short_module_name' attribute
    derived from the file name of the module defining the class. Classes
    whose name starts with 'Test' are additionally registered as test cases.
    """

    # Automatic discover of all subclasses
    _test_case_classes = []
    _test_case_names = set()
    # Keep track of started and finished test cases, so we know when to setup
    # and tear down GCP resources.
    _started_test_cases = set()
    _finished_test_cases = set()

    def __new__(cls, name: str, bases: Iterable[Any],
                attrs: Mapping[str, Any]) -> Any:
        """Creates the class, injecting the tracking attributes.

        Args:
            name: Name of the class being created.
            bases: Base classes of the class being created.
            attrs: Namespace of the class being created; updated in place.

        Returns:
            The newly created class.
        """
        # Hand over the tracking objects; all classes share the same ones.
        attrs.update(
            test_case_classes=cls._test_case_classes,
            test_case_names=cls._test_case_names,
            started_test_cases=cls._started_test_cases,
            finished_test_cases=cls._finished_test_cases,
        )
        # Handle the test name reflection: file name of the defining module,
        # without the test file suffix, with dashes instead of underscores.
        defining_module = sys.modules[attrs['__module__']]
        file_name = os.path.basename(defining_module.__file__)
        if file_name.endswith(URL_MAP_TESTCASE_FILE_SUFFIX):
            file_name = file_name.replace(URL_MAP_TESTCASE_FILE_SUFFIX, '')
        attrs['short_module_name'] = file_name.replace('_', '-')
        # Create the class and track
        new_class = super().__new__(cls, name, bases, attrs)
        if not name.startswith('Test'):
            logging.debug('Skipping test case class: %s', name)
            return new_class
        cls._test_case_names.add(name)
        cls._test_case_classes.append(new_class)
        return new_class
241
242
class XdsUrlMapTestCase(absltest.TestCase, metaclass=_MetaXdsUrlMapTestCase):
    """XdsUrlMapTestCase is the base class for urlMap related tests.

    The subclass is expected to implement 3 methods:

    - url_map_change: Updates the urlMap components for this test case
    - xds_config_validate: Validates if the client received legit xDS configs
    - rpc_distribution_validate: Validates if the routing behavior is correct
    """

    test_client_runner: Optional[_KubernetesClientRunner] = None

    @staticmethod
    def is_supported(config: skips.TestConfig) -> bool:
        """Allow the test case to decide whether it supports the given config.

        Returns:
          A bool indicates if the given config is supported.
        """
        del config
        return True

    @staticmethod
    def client_init_config(rpc: str, metadata: str) -> Tuple[str, str]:
        """Updates the initial RPC configs for this test case.

        Each test case will start a test client. The client takes RPC configs
        and starts to send RPCs immediately. The config returned by this
        function will be used to replace the default configs.

        The default configs are passed in as arguments, so this method can
        modify part of them.

        Args:
            rpc: The default rpc config, specifying RPCs to send, format
            'UnaryCall,EmptyCall'
            metadata: The metadata config, specifying metadata to send with each
            RPC, format 'EmptyCall:key1:value1,UnaryCall:key2:value2'.

        Returns:
            A tuple contains the updated rpc and metadata config.
        """
        return rpc, metadata

    @staticmethod
    @abc.abstractmethod
    def url_map_change(
            host_rule: HostRule,
            path_matcher: PathMatcher) -> Tuple[HostRule, PathMatcher]:
        """Updates the dedicated urlMap components for this test case.

        Each test case will have a dedicated HostRule, where the hostname is
        generated from the test case name. The HostRule will be linked to a
        PathMatcher, which stores the routing logic.

        Args:
            host_rule: A HostRule GCP resource as a JSON dict.
            path_matcher: A PathMatcher GCP resource as a JSON dict.

        Returns:
            A tuple contains the updated version of given HostRule and
            PathMatcher.
        """

    @abc.abstractmethod
    def xds_config_validate(self, xds_config: DumpedXdsConfig) -> None:
        """Validates received xDS config, if anything is wrong, raise.

        This stage only ends when the control plane failed to send a valid
        config within a given time range, like 600s.

        Args:
            xds_config: A DumpedXdsConfig instance can be used as a JSON dict,
              but also provides helper fields for commonly checked xDS config.
        """

    @abc.abstractmethod
    def rpc_distribution_validate(self, test_client: XdsTestClient) -> None:
        """Validates the routing behavior, if any is wrong, raise.

        Args:
            test_client: A XdsTestClient instance for all sorts of end2end testing.
        """

    @classmethod
    def hostname(cls):
        """Returns the 'host:port' this test case's client connects to.

        The host is built from the short module name and the de-camel-cased
        class name; the port is GcpResourceManager().server_xds_port.
        """
        return "%s.%s:%s" % (cls.short_module_name, _split_camel(
            cls.__name__), GcpResourceManager().server_xds_port)

    @classmethod
    def path_matcher_name(cls):
        """Returns the PathMatcher name derived from module and class name."""
        # Path matcher name must match r'(?:[a-z](?:[-a-z0-9]{0,61}[a-z0-9])?)'
        return "%s-%s-pm" % (cls.short_module_name, _split_camel(cls.__name__))

    @classmethod
    def setUpClass(cls):
        """Sets up the shared GCP resources (once) and starts the test client.

        Raises:
            unittest.SkipTest: If the test config is not supported by this
              test case (raised by skips.evaluate_test_config).
        """
        logging.info('----- Testing %s -----', cls.__name__)
        logging.info('Logs timezone: %s', time.localtime().tm_zone)

        # Raises unittest.SkipTest if given client/server/version does not
        # support current test case.
        skips.evaluate_test_config(cls.is_supported)

        # Configure cleanup to run after all tests regardless of
        # whether setUpClass failed.
        cls.addClassCleanup(cls.cleanupAfterTests)

        if not cls.started_test_cases:
            # Create the GCP resource once before the first test start
            GcpResourceManager().setup(cls.test_case_classes)
        cls.started_test_cases.add(cls.__name__)

        # Create the test case's own client runner with its own namespace,
        # enables concurrent running with other test cases.
        cls.test_client_runner = GcpResourceManager().create_test_client_runner(
        )
        # Start the client, and allow the test to override the initial RPC config.
        rpc, metadata = cls.client_init_config(rpc="UnaryCall,EmptyCall",
                                               metadata="")
        cls.test_client = cls.test_client_runner.run(
            server_target=f'xds:///{cls.hostname()}',
            rpc=rpc,
            metadata=metadata,
            qps=QPS.value,
            print_response=True)

    @classmethod
    def cleanupAfterTests(cls):
        """Class cleanup: tears down the client and, at the end, GCP resources.

        Registered through addClassCleanup in setUpClass. Cleanup failures are
        logged but do not fail the test; an unexpected restart of the client
        pods does.

        Raises:
            AssertionError: If the client pods restarted during the test.
        """
        logging.info('----- TestCase %s teardown -----', cls.__name__)
        client_restarts: int = 0
        if cls.test_client_runner:
            try:
                logging.debug('Getting pods restart times')
                client_restarts = cls.test_client_runner.get_pod_restarts(
                    cls.test_client_runner.deployment)
            except (retryers.RetryError, k8s.NotFound) as e:
                logging.exception(e)

        cls.finished_test_cases.add(cls.__name__)
        # Whether to clean up shared pre-provisioned infrastructure too.
        # We only do it after all tests are finished.
        cleanup_all = cls.finished_test_cases == cls.test_case_names

        # Graceful cleanup: try three times, and don't fail the test on
        # a cleanup failure.
        retryer = retryers.constant_retryer(wait_fixed=_timedelta(seconds=10),
                                            attempts=3,
                                            log_level=logging.INFO)
        try:
            retryer(cls._cleanup, cleanup_all)
        except retryers.RetryError:
            logging.exception('Got error during teardown')
        finally:
            # test_client_runner is always defined (class attribute defaults
            # to None), so a plain truthiness check is sufficient.
            if cls.test_client_runner:
                logging.info('----- Test client logs -----')
                cls.test_client_runner.logs_explorer_run_history_links()

            # Fail if any of the pods restarted. Raised explicitly rather than
            # with an assert statement, so the check survives python -O.
            if client_restarts != 0:
                raise AssertionError(
                    'Client pods unexpectedly restarted'
                    f' {client_restarts} times during test.'
                    ' In most cases, this is caused by the test client app crash.'
                )

    @classmethod
    def _cleanup(cls, cleanup_all: bool = False):
        """Cleans up the client runner, and the GCP resources if requested.

        Args:
            cleanup_all: Whether to also clean up the resources managed by
              GcpResourceManager.
        """
        if cls.test_client_runner:
            cls.test_client_runner.cleanup(force=True, force_namespace=True)
        if cleanup_all:
            GcpResourceManager().cleanup()

    def _fetch_and_check_xds_config(self):
        """Fetches the client's xDS config via CSDS and validates it.

        The fetched config is stored, as a JSON dict, in
        self._xds_json_config; it is reset to None at the start of each call.
        """
        # TODO(lidiz) find another way to store last seen xDS config
        # Cleanup state for this attempt
        self._xds_json_config = None  # pylint: disable=attribute-defined-outside-init
        # Fetch client config
        config = self.test_client.csds.fetch_client_status(
            log_level=logging.INFO)
        self.assertIsNotNone(config)
        # Found client config, test it.
        self._xds_json_config = json_format.MessageToDict(config)  # pylint: disable=attribute-defined-outside-init
        # Execute the child class provided validation logic
        self.xds_config_validate(DumpedXdsConfig(self._xds_json_config))

    def run(
        self,
        result: Optional[unittest.TestResult] = None
    ) -> Optional[unittest.TestResult]:
        """Abort this test case if CSDS check is failed.

        This prevents the test runner to waste time on RPC distribution test,
        and yields clearer signal.

        Args:
            result: The result object collecting the outcomes. When None, the
              test is always run and unittest creates a default result.

        Returns:
            The given result when the test is aborted, otherwise whatever the
            parent run() returns.
        """
        # result defaults to None; only an existing result can carry earlier
        # failures or errors.
        if result is not None and (result.failures or result.errors):
            logging.info('Aborting %s', self.__class__.__name__)
            return result
        return super().run(result)

    def test_client_config(self):
        """Retries fetching the xDS config until the validation passes.

        The latest fetched config is logged no matter whether the validation
        eventually passed.
        """
        retryer = retryers.constant_retryer(
            wait_fixed=datetime.timedelta(
                seconds=_URL_MAP_PROPAGATE_CHECK_INTERVAL_SEC),
            timeout=datetime.timedelta(seconds=_URL_MAP_PROPAGATE_TIMEOUT_SEC),
            logger=logging,
            log_level=logging.INFO)
        try:
            retryer(self._fetch_and_check_xds_config)
        finally:
            logging.info(
                'latest xDS config:\n%s',
                GcpResourceManager().td.compute.resource_pretty_format(
                    self._xds_json_config))

    def test_rpc_distribution(self):
        """Runs the subclass-provided routing behavior validation."""
        self.rpc_distribution_validate(self.test_client)

    @staticmethod
    def configure_and_send(test_client: XdsTestClient,
                           *,
                           rpc_types: Iterable[str],
                           metadata: Optional[Iterable[Tuple[str, str,
                                                             str]]] = None,
                           app_timeout: Optional[int] = None,
                           num_rpcs: int) -> RpcDistributionStats:
        """Reconfigures the test client, then collects its RPC distribution.

        Args:
            test_client: The test client to configure and query.
            rpc_types: Passed to the client's configure call as rpc_types.
            metadata: Passed to the client's configure call as metadata.
            app_timeout: Passed to the client's configure call as app_timeout.
            num_rpcs: Passed to get_load_balancer_stats as num_rpcs.

        Returns:
            The RpcDistributionStats computed from the client's load balancer
            stats response.
        """
        test_client.update_config.configure(rpc_types=rpc_types,
                                            metadata=metadata,
                                            app_timeout=app_timeout)
        # Configure RPC might race with get stats RPC on slower machines.
        time.sleep(_CLIENT_CONFIGURE_WAIT_SEC)
        json_lb_stats = json_format.MessageToDict(
            test_client.get_load_balancer_stats(num_rpcs=num_rpcs))
        logging.info(
            'Received LoadBalancerStatsResponse from test client %s:\n%s',
            test_client.hostname, json.dumps(json_lb_stats, indent=2))
        return RpcDistributionStats(json_lb_stats)

    def assertNumEndpoints(self, xds_config: DumpedXdsConfig, k: int) -> None:
        """Asserts that the xDS config has exactly k healthy endpoints."""
        self.assertLen(
            xds_config.endpoints, k,
            f'insufficient endpoints in EDS: want={k} seen={xds_config.endpoints}'
        )

    def assertRpcStatusCode(  # pylint: disable=too-many-locals
            self, test_client: XdsTestClient, *,
            expected: Iterable[ExpectedResult], length: int,
            tolerance: float) -> None:
        """Assert the distribution of RPC statuses over a period of time.

        Takes a snapshot of the client's accumulated stats, sleeps for
        `length` seconds, takes another snapshot, and checks the difference.

        Args:
            test_client: The test client to query the accumulated stats from.
            expected: The expected results to check, one per RPC type and
              status code combination.
            length: Seconds to wait between the two snapshots.
            tolerance: Maximum allowed value of
              abs(seen - total * ratio) / total.
        """
        # Sending with pre-set QPS for a period of time
        before_stats = test_client.get_load_balancer_accumulated_stats()
        logging.info(
            'Received LoadBalancerAccumulatedStatsResponse from test client %s: before:\n%s',
            test_client.hostname, before_stats)
        time.sleep(length)
        after_stats = test_client.get_load_balancer_accumulated_stats()
        logging.info(
            'Received LoadBalancerAccumulatedStatsResponse from test client %s: after: \n%s',
            test_client.hostname, after_stats)

        # Validate the diff
        for expected_result in expected:
            rpc = expected_result.rpc_type
            status = expected_result.status_code.value[0]
            # Compute observation
            # ProtoBuf messages has special magic dictionary that we don't need
            # to catch exceptions:
            # https://developers.google.com/protocol-buffers/docs/reference/python-generated#undefined
            seen_after = after_stats.stats_per_method[rpc].result[status]
            seen_before = before_stats.stats_per_method[rpc].result[status]
            seen = seen_after - seen_before
            # Compute total number of RPC started
            stats_per_method_after = after_stats.stats_per_method.get(
                rpc, {}).result.items()
            total_after = sum(
                x[1] for x in stats_per_method_after)  # (status_code, count)
            stats_per_method_before = before_stats.stats_per_method.get(
                rpc, {}).result.items()
            total_before = sum(
                x[1] for x in stats_per_method_before)  # (status_code, count)
            total = total_after - total_before
            # Without any finished RPC the ratio is undefined; fail with a
            # clear message instead of dividing by zero below.
            self.assertGreater(
                total, 0,
                (f'No rpc [{rpc}] finished within {length}s, cannot verify '
                 f'the ratio of [{expected_result.status_code}]'))
            # Compute and validate the number
            want = total * expected_result.ratio
            diff_ratio = abs(seen - want) / total
            self.assertLessEqual(
                diff_ratio, tolerance,
                (f'Expect rpc [{rpc}] to return '
                 f'[{expected_result.status_code}] at '
                 f'{expected_result.ratio:.2f} ratio: '
                 f'seen={seen} want={want} total={total} '
                 f'diff_ratio={diff_ratio:.4f} > {tolerance:.2f}'))
529