1# Copyright 2020 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""
15Provides an interface to xDS Test Client running remotely.
16"""
17import datetime
18import functools
19import logging
20from typing import Iterable, List, Optional
21
22from framework.helpers import retryers
23import framework.rpc
24from framework.rpc import grpc_channelz
25from framework.rpc import grpc_csds
26from framework.rpc import grpc_testing
27
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type aliases: short module-private names for the service clients and
# Channelz types used by XdsTestClient below.
_timedelta = datetime.timedelta
_LoadBalancerStatsServiceClient = grpc_testing.LoadBalancerStatsServiceClient
_XdsUpdateClientConfigureServiceClient = grpc_testing.XdsUpdateClientConfigureServiceClient
_ChannelzServiceClient = grpc_channelz.ChannelzServiceClient
_ChannelzChannel = grpc_channelz.Channel
_ChannelzChannelState = grpc_channelz.ChannelState
_ChannelzSubchannel = grpc_channelz.Subchannel
_ChannelzSocket = grpc_channelz.Socket
_CsdsClient = grpc_csds.CsdsClient
40
41
class XdsTestClient(framework.rpc.grpc.GrpcApp):
    """
    Represents RPC services implemented in Client component of the xds test app.
    https://github.com/grpc/grpc/blob/master/doc/xds-test-descriptions.md#client
    """
    # A unique string identifying each client replica. Used in logging.
    hostname: str

    def __init__(self,
                 *,
                 ip: str,
                 rpc_port: int,
                 server_target: str,
                 hostname: str,
                 rpc_host: Optional[str] = None,
                 maintenance_port: Optional[int] = None):
        """Describe a remotely running xDS test client.

        Args:
            ip: IP address of the client; used as the RPC host when
                `rpc_host` is not given.
            rpc_port: Port serving LoadBalancerStatsService and
                XdsUpdateClientConfigureService.
            server_target: Target the client sends RPCs to; used to find
                the client's channels to the server via Channelz.
            hostname: Identifies this client replica in log messages.
            rpc_host: Host to connect to; defaults to `ip`.
            maintenance_port: Port serving Channelz and CSDS; defaults to
                `rpc_port` (also when given as a falsy value such as 0).
        """
        super().__init__(rpc_host=(rpc_host or ip))
        self.ip = ip
        self.rpc_port = rpc_port
        self.server_target = server_target
        self.maintenance_port = maintenance_port or rpc_port
        self.hostname = hostname

    # NOTE(review): lru_cache on a method keys on `self` and keeps every
    # instance alive for the lifetime of the process. If the supported Python
    # versions allow it (3.8+), functools.cached_property avoids that.
    @property
    @functools.lru_cache(None)
    def load_balancer_stats(self) -> _LoadBalancerStatsServiceClient:
        """LoadBalancerStatsService client on `rpc_port`, created once."""
        return _LoadBalancerStatsServiceClient(
            self._make_channel(self.rpc_port),
            log_target=f'{self.hostname}:{self.rpc_port}')

    @property
    @functools.lru_cache(None)
    def update_config(self) -> _XdsUpdateClientConfigureServiceClient:
        """XdsUpdateClientConfigureService client on `rpc_port`, created once."""
        return _XdsUpdateClientConfigureServiceClient(
            self._make_channel(self.rpc_port),
            log_target=f'{self.hostname}:{self.rpc_port}')

    @property
    @functools.lru_cache(None)
    def channelz(self) -> _ChannelzServiceClient:
        """Channelz client on `maintenance_port`, created once."""
        return _ChannelzServiceClient(
            self._make_channel(self.maintenance_port),
            log_target=f'{self.hostname}:{self.maintenance_port}')

    @property
    @functools.lru_cache(None)
    def csds(self) -> _CsdsClient:
        """CSDS client on `maintenance_port`, created once."""
        return _CsdsClient(
            self._make_channel(self.maintenance_port),
            log_target=f'{self.hostname}:{self.maintenance_port}')

    def get_load_balancer_stats(
        self,
        *,
        num_rpcs: int,
        timeout_sec: Optional[int] = None,
    ) -> grpc_testing.LoadBalancerStatsResponse:
        """
        Shortcut to LoadBalancerStatsServiceClient.get_client_stats()
        """
        return self.load_balancer_stats.get_client_stats(
            num_rpcs=num_rpcs, timeout_sec=timeout_sec)

    def get_load_balancer_accumulated_stats(
        self,
        *,
        timeout_sec: Optional[int] = None,
    ) -> grpc_testing.LoadBalancerAccumulatedStatsResponse:
        """Shortcut to LoadBalancerStatsServiceClient.get_client_accumulated_stats()"""
        return self.load_balancer_stats.get_client_accumulated_stats(
            timeout_sec=timeout_sec)

    def wait_for_active_server_channel(self) -> _ChannelzChannel:
        """Wait for the channel to the server to transition to READY.

        Uses the default timeout and RPC deadline of
        `wait_for_server_channel_state`.

        Raises:
            Whatever the retryer raises once its timeout expires while
            `find_server_channel_with_state` keeps raising GrpcApp.NotFound.
            NOTE(review): confirm the exact exception type against
            framework.helpers.retryers.
        """
        return self.wait_for_server_channel_state(_ChannelzChannelState.READY)

    def get_active_server_channel_socket(self) -> _ChannelzSocket:
        """Return the socket backing the READY channel to the server.

        Finds a READY channel to `server_target`, takes its first subchannel,
        and returns that subchannel's first socket. Any additional
        subchannels or sockets are logged as unexpected.

        Raises:
            GrpcApp.NotFound: If there is no READY channel to the server, the
                channel has no subchannels, or the subchannel has no sockets.
        """
        channel = self.find_server_channel_with_state(
            _ChannelzChannelState.READY)
        # Logging args are evaluated eagerly, so guard the index: the
        # channel snapshot may not list any subchannel refs.
        first_subchannel_name = (channel.subchannel_ref[0].name
                                 if channel.subchannel_ref else None)
        logger.debug(
            '[%s] Retrieving client -> server socket, '
            'channel_id: %s, subchannel: %s', self.hostname,
            channel.ref.channel_id, first_subchannel_name)

        # Get the first subchannel of the active channel to the server.
        subchannels = list(self.channelz.list_channel_subchannels(channel))
        if not subchannels:
            raise self.NotFound(
                f'[{self.hostname}] No subchannels found for '
                f'channel_id {channel.ref.channel_id}')
        subchannel, *extra_subchannels = subchannels
        if extra_subchannels:
            logger.warning('[%s] Unexpected subchannels: %r', self.hostname,
                           extra_subchannels)

        # Get the first socket of the subchannel.
        sockets = list(self.channelz.list_subchannels_sockets(subchannel))
        if not sockets:
            raise self.NotFound(
                f'[{self.hostname}] No sockets found for subchannel '
                f'{_ChannelzServiceClient.subchannel_repr(subchannel)}')
        socket, *extra_sockets = sockets
        if extra_sockets:
            logger.warning('[%s] Unexpected sockets: %r', self.hostname,
                           extra_sockets)
        logger.debug('[%s] Found client -> server socket: %s', self.hostname,
                     socket.ref.name)
        return socket

    def wait_for_server_channel_state(
            self,
            state: _ChannelzChannelState,
            *,
            timeout: Optional[_timedelta] = None,
            rpc_deadline: Optional[_timedelta] = None) -> _ChannelzChannel:
        """Poll Channelz until a channel to the server reaches `state`.

        Args:
            state: Desired Channelz connectivity state.
            timeout: Total time to keep retrying; defaults to 5 minutes.
            rpc_deadline: Deadline for each individual Channelz RPC;
                defaults to 30 seconds.

        Returns:
            The first matching channel (see `find_server_channel_with_state`).
        """
        # When polling for a state, prefer smaller wait times to avoid
        # exhausting all allowed time on a single long RPC.
        if rpc_deadline is None:
            rpc_deadline = _timedelta(seconds=30)

        # Fine-tuned to wait for the channel to the server.
        retryer = retryers.exponential_retryer_with_timeout(
            wait_min=_timedelta(seconds=10),
            wait_max=_timedelta(seconds=25),
            timeout=_timedelta(minutes=5) if timeout is None else timeout)

        logger.info('[%s] Waiting to report a %s channel to %s', self.hostname,
                    _ChannelzChannelState.Name(state), self.server_target)
        channel = retryer(self.find_server_channel_with_state,
                          state,
                          rpc_deadline=rpc_deadline)
        logger.info('[%s] Channel to %s transitioned to state %s: %s',
                    self.hostname, self.server_target,
                    _ChannelzChannelState.Name(state),
                    _ChannelzServiceClient.channel_repr(channel))
        return channel

    def find_server_channel_with_state(
            self,
            state: _ChannelzChannelState,
            *,
            rpc_deadline: Optional[_timedelta] = None,
            check_subchannel: bool = True) -> _ChannelzChannel:
        """Return the first channel to the server that is in `state`.

        Args:
            state: Desired Channelz connectivity state.
            rpc_deadline: Optional deadline for each Channelz RPC.
            check_subchannel: When True, only accept a channel that also has
                at least one subchannel in `state`.

        Raises:
            GrpcApp.NotFound: If no channel to the server matches.
        """
        rpc_params = {}
        if rpc_deadline is not None:
            rpc_params['deadline_sec'] = rpc_deadline.total_seconds()

        for channel in self.get_server_channels(**rpc_params):
            logger.info('[%s] Server channel: %s', self.hostname,
                        _ChannelzServiceClient.channel_repr(channel))
            # Channelz states are protobuf enum ints: compare by value.
            if channel.data.state.state != state:
                continue
            if not check_subchannel:
                return channel
            # Require at least one subchannel in the requested state;
            # otherwise keep searching.
            try:
                subchannel = self.find_subchannel_with_state(
                    channel, state, **rpc_params)
            except self.NotFound as e:
                logger.info(e.message)
                continue
            logger.info('[%s] Found subchannel in state %s: %s',
                        self.hostname, _ChannelzChannelState.Name(state),
                        _ChannelzServiceClient.subchannel_repr(subchannel))
            return channel

        raise self.NotFound(
            f'[{self.hostname}] Client has no '
            f'{_ChannelzChannelState.Name(state)} channel with the server')

    def get_server_channels(self, **kwargs) -> Iterable[_ChannelzChannel]:
        """Return the client's Channelz channels targeting `server_target`.

        `kwargs` are forwarded to ChannelzServiceClient.find_channels_for_target
        (e.g. `deadline_sec`).
        """
        return self.channelz.find_channels_for_target(self.server_target,
                                                      **kwargs)

    def find_subchannel_with_state(self, channel: _ChannelzChannel,
                                   state: _ChannelzChannelState,
                                   **kwargs) -> _ChannelzSubchannel:
        """Return the first subchannel of `channel` that is in `state`.

        Raises:
            GrpcApp.NotFound: If no subchannel of `channel` is in `state`.
        """
        subchannels = self.channelz.list_channel_subchannels(channel, **kwargs)
        for subchannel in subchannels:
            if subchannel.data.state.state == state:
                return subchannel

        raise self.NotFound(f'[{self.hostname}] Not found '
                            f'a {_ChannelzChannelState.Name(state)} subchannel '
                            f'for channel_id {channel.ref.channel_id}')

    def find_subchannels_with_state(self, state: _ChannelzChannelState,
                                    **kwargs) -> List[_ChannelzSubchannel]:
        """Return all subchannels, across all server channels, in `state`.

        `kwargs` are forwarded to every Channelz RPC made.
        """
        subchannels = []
        for channel in self.get_server_channels(**kwargs):
            for subchannel in self.channelz.list_channel_subchannels(
                    channel, **kwargs):
                if subchannel.data.state.state == state:
                    subchannels.append(subchannel)
        return subchannels
234