xref: /aosp_15_r20/external/grpc-grpc/test/core/client_channel/lb_policy/lb_policy_test_lib.h (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 //
2 // Copyright 2022 gRPC authors.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #ifndef GRPC_TEST_CORE_CLIENT_CHANNEL_LB_POLICY_LB_POLICY_TEST_LIB_H
18 #define GRPC_TEST_CORE_CLIENT_CHANNEL_LB_POLICY_LB_POLICY_TEST_LIB_H
19 
20 #include <inttypes.h>
21 #include <stddef.h>
22 
23 #include <algorithm>
24 #include <chrono>
25 #include <deque>
26 #include <functional>
27 #include <map>
28 #include <memory>
29 #include <set>
30 #include <string>
31 #include <tuple>
32 #include <type_traits>
33 #include <utility>
34 #include <vector>
35 
36 #include "absl/base/thread_annotations.h"
37 #include "absl/functional/any_invocable.h"
38 #include "absl/status/status.h"
39 #include "absl/status/statusor.h"
40 #include "absl/strings/str_format.h"
41 #include "absl/strings/str_join.h"
42 #include "absl/strings/string_view.h"
43 #include "absl/synchronization/notification.h"
44 #include "absl/types/optional.h"
45 #include "absl/types/span.h"
46 #include "absl/types/variant.h"
47 #include "gmock/gmock.h"
48 #include "gtest/gtest.h"
49 
50 #include <grpc/event_engine/event_engine.h>
51 #include <grpc/grpc.h>
52 #include <grpc/support/alloc.h>
53 #include <grpc/support/log.h>
54 #include <grpc/support/port_platform.h>
55 
56 #include "src/core/client_channel/client_channel_internal.h"
57 #include "src/core/client_channel/subchannel_interface_internal.h"
58 #include "src/core/client_channel/subchannel_pool_interface.h"
59 #include "src/core/lib/address_utils/parse_address.h"
60 #include "src/core/lib/address_utils/sockaddr_utils.h"
61 #include "src/core/lib/channel/channel_args.h"
62 #include "src/core/lib/config/core_configuration.h"
63 #include "src/core/lib/event_engine/default_event_engine.h"
64 #include "src/core/lib/gprpp/debug_location.h"
65 #include "src/core/lib/gprpp/match.h"
66 #include "src/core/lib/gprpp/orphanable.h"
67 #include "src/core/lib/gprpp/ref_counted_ptr.h"
68 #include "src/core/lib/gprpp/sync.h"
69 #include "src/core/lib/gprpp/time.h"
70 #include "src/core/lib/gprpp/unique_type_name.h"
71 #include "src/core/lib/gprpp/work_serializer.h"
72 #include "src/core/lib/iomgr/exec_ctx.h"
73 #include "src/core/lib/iomgr/resolved_address.h"
74 #include "src/core/lib/json/json.h"
75 #include "src/core/lib/security/credentials/credentials.h"
76 #include "src/core/lib/transport/connectivity_state.h"
77 #include "src/core/lib/uri/uri_parser.h"
78 #include "src/core/load_balancing/backend_metric_data.h"
79 #include "src/core/load_balancing/health_check_client_internal.h"
80 #include "src/core/load_balancing/lb_policy.h"
81 #include "src/core/load_balancing/lb_policy_registry.h"
82 #include "src/core/load_balancing/oob_backend_metric.h"
83 #include "src/core/load_balancing/oob_backend_metric_internal.h"
84 #include "src/core/load_balancing/subchannel_interface.h"
85 #include "src/core/resolver/endpoint_addresses.h"
86 #include "src/core/service_config/service_config_call_data.h"
87 #include "test/core/event_engine/event_engine_test_utils.h"
88 #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h"
89 #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h"
90 
91 namespace grpc_core {
92 namespace testing {
93 
94 class LoadBalancingPolicyTest : public ::testing::Test {
95  protected:
96   using CallAttributes =
97       std::vector<ServiceConfigCallData::CallAttributeInterface*>;
98 
99   // Channel-level subchannel state for a specific address and channel args.
100   // This is analogous to the real subchannel in the ClientChannel code.
101   class SubchannelState {
102    public:
103     // A fake SubchannelInterface object, to be returned to the LB
104     // policy when it calls the helper's CreateSubchannel() method.
105     // There may be multiple FakeSubchannel objects associated with a
106     // given SubchannelState object.
107     class FakeSubchannel : public SubchannelInterface {
108      public:
FakeSubchannel(SubchannelState * state)109       explicit FakeSubchannel(SubchannelState* state) : state_(state) {}
110 
~FakeSubchannel()111       ~FakeSubchannel() override {
112         if (orca_watcher_ != nullptr) {
113           MutexLock lock(&state_->backend_metric_watcher_mu_);
114           state_->orca_watchers_.erase(orca_watcher_.get());
115         }
116         for (const auto& p : watcher_map_) {
117           state_->state_tracker_.RemoveWatcher(p.second);
118         }
119       }
120 
state()121       SubchannelState* state() const { return state_; }
122 
123      private:
124       // Converts between
125       // SubchannelInterface::ConnectivityStateWatcherInterface and
126       // ConnectivityStateWatcherInterface.
127       //
128       // We support both unique_ptr<> and shared_ptr<>, since raw
129       // connectivity watches use the latter but health watches use the
130       // former.
131       // TODO(roth): Clean this up.
132       class WatcherWrapper : public AsyncConnectivityStateWatcherInterface {
133        public:
WatcherWrapper(std::shared_ptr<WorkSerializer> work_serializer,std::unique_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> watcher)134         WatcherWrapper(
135             std::shared_ptr<WorkSerializer> work_serializer,
136             std::unique_ptr<
137                 SubchannelInterface::ConnectivityStateWatcherInterface>
138                 watcher)
139             : AsyncConnectivityStateWatcherInterface(
140                   std::move(work_serializer)),
141               watcher_(std::move(watcher)) {}
142 
WatcherWrapper(std::shared_ptr<WorkSerializer> work_serializer,std::shared_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> watcher)143         WatcherWrapper(
144             std::shared_ptr<WorkSerializer> work_serializer,
145             std::shared_ptr<
146                 SubchannelInterface::ConnectivityStateWatcherInterface>
147                 watcher)
148             : AsyncConnectivityStateWatcherInterface(
149                   std::move(work_serializer)),
150               watcher_(std::move(watcher)) {}
151 
OnConnectivityStateChange(grpc_connectivity_state new_state,const absl::Status & status)152         void OnConnectivityStateChange(grpc_connectivity_state new_state,
153                                        const absl::Status& status) override {
154           gpr_log(GPR_INFO, "notifying watcher: state=%s status=%s",
155                   ConnectivityStateName(new_state), status.ToString().c_str());
156           watcher_->OnConnectivityStateChange(new_state, status);
157         }
158 
159        private:
160         std::shared_ptr<SubchannelInterface::ConnectivityStateWatcherInterface>
161             watcher_;
162       };
163 
WatchConnectivityState(std::unique_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> watcher)164       void WatchConnectivityState(
165           std::unique_ptr<
166               SubchannelInterface::ConnectivityStateWatcherInterface>
167               watcher) override
168           ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) {
169         auto* watcher_ptr = watcher.get();
170         auto watcher_wrapper = MakeOrphanable<WatcherWrapper>(
171             state_->work_serializer(), std::move(watcher));
172         watcher_map_[watcher_ptr] = watcher_wrapper.get();
173         state_->state_tracker_.AddWatcher(GRPC_CHANNEL_SHUTDOWN,
174                                           std::move(watcher_wrapper));
175       }
176 
CancelConnectivityStateWatch(ConnectivityStateWatcherInterface * watcher)177       void CancelConnectivityStateWatch(
178           ConnectivityStateWatcherInterface* watcher) override
179           ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) {
180         auto it = watcher_map_.find(watcher);
181         if (it == watcher_map_.end()) return;
182         state_->state_tracker_.RemoveWatcher(it->second);
183         watcher_map_.erase(it);
184       }
185 
RequestConnection()186       void RequestConnection() override {
187         MutexLock lock(&state_->requested_connection_mu_);
188         state_->requested_connection_ = true;
189       }
190 
AddDataWatcher(std::unique_ptr<DataWatcherInterface> watcher)191       void AddDataWatcher(
192           std::unique_ptr<DataWatcherInterface> watcher) override
193           ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) {
194         MutexLock lock(&state_->backend_metric_watcher_mu_);
195         auto* w =
196             static_cast<InternalSubchannelDataWatcherInterface*>(watcher.get());
197         if (w->type() == OrcaProducer::Type()) {
198           GPR_ASSERT(orca_watcher_ == nullptr);
199           orca_watcher_.reset(static_cast<OrcaWatcher*>(watcher.release()));
200           state_->orca_watchers_.insert(orca_watcher_.get());
201         } else if (w->type() == HealthProducer::Type()) {
202           // TODO(roth): Support health checking in test framework.
203           // For now, we just hard-code this to the raw connectivity state.
204           GPR_ASSERT(health_watcher_ == nullptr);
205           GPR_ASSERT(health_watcher_wrapper_ == nullptr);
206           health_watcher_.reset(static_cast<HealthWatcher*>(watcher.release()));
207           auto connectivity_watcher = health_watcher_->TakeWatcher();
208           auto* connectivity_watcher_ptr = connectivity_watcher.get();
209           auto watcher_wrapper = MakeOrphanable<WatcherWrapper>(
210               state_->work_serializer(), std::move(connectivity_watcher));
211           health_watcher_wrapper_ = watcher_wrapper.get();
212           state_->state_tracker_.AddWatcher(GRPC_CHANNEL_SHUTDOWN,
213                                             std::move(watcher_wrapper));
214           gpr_log(GPR_INFO,
215                   "AddDataWatcher(): added HealthWatch=%p "
216                   "connectivity_watcher=%p watcher_wrapper=%p",
217                   health_watcher_.get(), connectivity_watcher_ptr,
218                   health_watcher_wrapper_);
219         }
220       }
221 
CancelDataWatcher(DataWatcherInterface * watcher)222       void CancelDataWatcher(DataWatcherInterface* watcher) override
223           ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) {
224         MutexLock lock(&state_->backend_metric_watcher_mu_);
225         auto* w = static_cast<InternalSubchannelDataWatcherInterface*>(watcher);
226         if (w->type() == OrcaProducer::Type()) {
227           if (orca_watcher_.get() != static_cast<OrcaWatcher*>(watcher)) return;
228           state_->orca_watchers_.erase(orca_watcher_.get());
229           orca_watcher_.reset();
230         } else if (w->type() == HealthProducer::Type()) {
231           if (health_watcher_.get() != static_cast<HealthWatcher*>(watcher)) {
232             return;
233           }
234           gpr_log(GPR_INFO,
235                   "CancelDataWatcher(): cancelling HealthWatch=%p "
236                   "watcher_wrapper=%p",
237                   health_watcher_.get(), health_watcher_wrapper_);
238           state_->state_tracker_.RemoveWatcher(health_watcher_wrapper_);
239           health_watcher_wrapper_ = nullptr;
240           health_watcher_.reset();
241         }
242       }
243 
244       // Don't need this method, so it's a no-op.
ResetBackoff()245       void ResetBackoff() override {}
246 
247       SubchannelState* state_;
248       std::map<SubchannelInterface::ConnectivityStateWatcherInterface*,
249                WatcherWrapper*>
250           watcher_map_;
251       std::unique_ptr<HealthWatcher> health_watcher_;
252       WatcherWrapper* health_watcher_wrapper_ = nullptr;
253       std::unique_ptr<OrcaWatcher> orca_watcher_;
254     };
255 
SubchannelState(absl::string_view address,LoadBalancingPolicyTest * test)256     SubchannelState(absl::string_view address, LoadBalancingPolicyTest* test)
257         : address_(address),
258           test_(test),
259           state_tracker_("LoadBalancingPolicyTest") {}
260 
address()261     const std::string& address() const { return address_; }
262 
263     void AssertValidConnectivityStateTransition(
264         grpc_connectivity_state from_state, grpc_connectivity_state to_state,
265         SourceLocation location = SourceLocation()) {
266       switch (from_state) {
267         case GRPC_CHANNEL_IDLE:
268           ASSERT_EQ(to_state, GRPC_CHANNEL_CONNECTING)
269               << ConnectivityStateName(from_state) << "=>"
270               << ConnectivityStateName(to_state) << "\n"
271               << location.file() << ":" << location.line();
272           break;
273         case GRPC_CHANNEL_CONNECTING:
274           ASSERT_THAT(to_state,
275                       ::testing::AnyOf(GRPC_CHANNEL_READY,
276                                        GRPC_CHANNEL_TRANSIENT_FAILURE))
277               << ConnectivityStateName(from_state) << "=>"
278               << ConnectivityStateName(to_state) << "\n"
279               << location.file() << ":" << location.line();
280           break;
281         case GRPC_CHANNEL_READY:
282           ASSERT_EQ(to_state, GRPC_CHANNEL_IDLE)
283               << ConnectivityStateName(from_state) << "=>"
284               << ConnectivityStateName(to_state) << "\n"
285               << location.file() << ":" << location.line();
286           break;
287         case GRPC_CHANNEL_TRANSIENT_FAILURE:
288           ASSERT_EQ(to_state, GRPC_CHANNEL_IDLE)
289               << ConnectivityStateName(from_state) << "=>"
290               << ConnectivityStateName(to_state) << "\n"
291               << location.file() << ":" << location.line();
292           break;
293         default:
294           FAIL() << ConnectivityStateName(from_state) << "=>"
295                  << ConnectivityStateName(to_state) << "\n"
296                  << location.file() << ":" << location.line();
297           break;
298       }
299     }
300 
301     // Sets the connectivity state for this subchannel.  The updated state
302     // will be reported to all associated SubchannelInterface objects.
303     void SetConnectivityState(grpc_connectivity_state state,
304                               const absl::Status& status = absl::OkStatus(),
305                               bool validate_state_transition = true,
306                               SourceLocation location = SourceLocation()) {
307       ExecCtx exec_ctx;
308       if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) {
309         EXPECT_FALSE(status.ok())
310             << "bug in test: TRANSIENT_FAILURE must have non-OK status";
311       } else {
312         EXPECT_TRUE(status.ok())
313             << "bug in test: " << ConnectivityStateName(state)
314             << " must have OK status: " << status;
315       }
316       // Updating the state in the state tracker will enqueue
317       // notifications to watchers on the WorkSerializer.  If any
318       // subchannel reports READY, the pick_first leaf policy will then
319       // start a health watch, whose initial notification will also be
320       // scheduled on the WorkSerializer.  We don't want to return until
321       // all of those notifications have been delivered.
322       absl::Notification notification;
323       test_->work_serializer_->Run(
324           [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(*test_->work_serializer_) {
325             if (validate_state_transition) {
326               AssertValidConnectivityStateTransition(state_tracker_.state(),
327                                                      state, location);
328             }
329             gpr_log(GPR_INFO, "Setting state on tracker");
330             state_tracker_.SetState(state, status, "set from test");
331             // SetState() enqueued the connectivity state notifications for
332             // the subchannel, so we add another callback to the queue to be
333             // executed after that state notifications has been delivered.
334             gpr_log(GPR_INFO,
335                     "Waiting for state notifications to be delivered");
336             test_->work_serializer_->Run(
337                 [&]() {
338                   gpr_log(GPR_INFO,
339                           "State notifications delivered, waiting for health "
340                           "notifications");
341                   // Now the connectivity state notifications has been
342                   // delivered. If the state reported was READY, then the
343                   // pick_first leaf policy will have started a health watch, so
344                   // we add another callback to the queue to be executed after
345                   // the initial health watch notification has been delivered.
346                   test_->work_serializer_->Run([&]() { notification.Notify(); },
347                                                DEBUG_LOCATION);
348                 },
349                 DEBUG_LOCATION);
350           },
351           DEBUG_LOCATION);
352       notification.WaitForNotification();
353       gpr_log(GPR_INFO, "Health notifications delivered");
354     }
355 
356     // Indicates if any of the associated SubchannelInterface objects
357     // have requested a connection attempt since the last time this
358     // method was called.
ConnectionRequested()359     bool ConnectionRequested() {
360       MutexLock lock(&requested_connection_mu_);
361       return std::exchange(requested_connection_, false);
362     }
363 
364     // To be invoked by FakeHelper.
CreateSubchannel()365     RefCountedPtr<SubchannelInterface> CreateSubchannel() {
366       return MakeRefCounted<FakeSubchannel>(this);
367     }
368 
369     // Sends an OOB backend metric report to all watchers.
SendOobBackendMetricReport(const BackendMetricData & backend_metrics)370     void SendOobBackendMetricReport(const BackendMetricData& backend_metrics) {
371       MutexLock lock(&backend_metric_watcher_mu_);
372       for (const auto* watcher : orca_watchers_) {
373         watcher->watcher()->OnBackendMetricReport(backend_metrics);
374       }
375     }
376 
377     // Checks that all OOB watchers have the expected reporting period.
378     void CheckOobReportingPeriod(Duration expected,
379                                  SourceLocation location = SourceLocation()) {
380       MutexLock lock(&backend_metric_watcher_mu_);
381       for (const auto* watcher : orca_watchers_) {
382         EXPECT_EQ(watcher->report_interval(), expected)
383             << location.file() << ":" << location.line();
384       }
385     }
386 
NumWatchers()387     size_t NumWatchers() const {
388       size_t num_watchers;
389       absl::Notification notification;
390       work_serializer()->Run(
391           [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(*test_->work_serializer_) {
392             num_watchers = state_tracker_.NumWatchers();
393             notification.Notify();
394           },
395           DEBUG_LOCATION);
396       notification.WaitForNotification();
397       return num_watchers;
398     }
399 
work_serializer()400     std::shared_ptr<WorkSerializer> work_serializer() const {
401       return test_->work_serializer_;
402     }
403 
state_tracker()404     ConnectivityStateTracker& state_tracker() { return state_tracker_; }
405 
406    private:
407     const std::string address_;
408     LoadBalancingPolicyTest* const test_;
409     ConnectivityStateTracker state_tracker_
410         ABSL_GUARDED_BY(*test_->work_serializer_);
411 
412     Mutex requested_connection_mu_;
413     bool requested_connection_ ABSL_GUARDED_BY(&requested_connection_mu_) =
414         false;
415 
416     Mutex backend_metric_watcher_mu_;
417     std::set<OrcaWatcher*> orca_watchers_
418         ABSL_GUARDED_BY(&backend_metric_watcher_mu_);
419   };
420 
421   // A fake helper to be passed to the LB policy.
422   class FakeHelper : public LoadBalancingPolicy::ChannelControlHelper {
423    public:
424     // Represents a state update reported by the LB policy.
425     struct StateUpdate {
426       grpc_connectivity_state state;
427       absl::Status status;
428       RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker;
429 
ToStringStateUpdate430       std::string ToString() const {
431         return absl::StrFormat("UPDATE{state=%s, status=%s, picker=%p}",
432                                ConnectivityStateName(state), status.ToString(),
433                                picker.get());
434       }
435     };
436 
437     // Represents a re-resolution request from the LB policy.
438     struct ReresolutionRequested {
ToStringReresolutionRequested439       std::string ToString() const { return "RERESOLUTION"; }
440     };
441 
FakeHelper(LoadBalancingPolicyTest * test)442     explicit FakeHelper(LoadBalancingPolicyTest* test) : test_(test) {}
443 
QueueEmpty()444     bool QueueEmpty() {
445       MutexLock lock(&mu_);
446       return queue_.empty();
447     }
448 
449     // Called at test tear-down time to ensure that we have not left any
450     // unexpected events in the queue.
451     void ExpectQueueEmpty(SourceLocation location = SourceLocation()) {
452       MutexLock lock(&mu_);
453       EXPECT_TRUE(queue_.empty())
454           << location.file() << ":" << location.line() << "\n"
455           << QueueString();
456     }
457 
458     // Returns the next event in the queue if it is a state update.
459     // If the queue is empty or the next event is not a state update,
460     // fails the test and returns nullopt without removing anything from
461     // the queue.
462     absl::optional<StateUpdate> GetNextStateUpdate(
463         SourceLocation location = SourceLocation()) {
464       MutexLock lock(&mu_);
465       EXPECT_FALSE(queue_.empty()) << location.file() << ":" << location.line();
466       if (queue_.empty()) return absl::nullopt;
467       Event& event = queue_.front();
468       auto* update = absl::get_if<StateUpdate>(&event);
469       EXPECT_NE(update, nullptr)
470           << "unexpected event " << EventString(event) << " at "
471           << location.file() << ":" << location.line();
472       if (update == nullptr) return absl::nullopt;
473       StateUpdate result = std::move(*update);
474       gpr_log(GPR_INFO, "dequeued next state update: %s",
475               result.ToString().c_str());
476       queue_.pop_front();
477       return std::move(result);
478     }
479 
480     // Returns the next event in the queue if it is a re-resolution.
481     // If the queue is empty or the next event is not a re-resolution,
482     // fails the test and returns nullopt without removing anything
483     // from the queue.
484     absl::optional<ReresolutionRequested> GetNextReresolution(
485         SourceLocation location = SourceLocation()) {
486       MutexLock lock(&mu_);
487       EXPECT_FALSE(queue_.empty()) << location.file() << ":" << location.line();
488       if (queue_.empty()) return absl::nullopt;
489       Event& event = queue_.front();
490       auto* reresolution = absl::get_if<ReresolutionRequested>(&event);
491       EXPECT_NE(reresolution, nullptr)
492           << "unexpected event " << EventString(event) << " at "
493           << location.file() << ":" << location.line();
494       if (reresolution == nullptr) return absl::nullopt;
495       ReresolutionRequested result = *reresolution;
496       queue_.pop_front();
497       return result;
498     }
499 
500    private:
501     // A wrapper for a picker that hops into the WorkSerializer to
502     // release the ref to the picker.
503     class PickerWrapper : public LoadBalancingPolicy::SubchannelPicker {
504      public:
PickerWrapper(LoadBalancingPolicyTest * test,RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker)505       PickerWrapper(LoadBalancingPolicyTest* test,
506                     RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker)
507           : test_(test), picker_(std::move(picker)) {
508         gpr_log(GPR_INFO, "creating wrapper %p for picker %p", this,
509                 picker_.get());
510       }
511 
Orphaned()512       void Orphaned() override {
513         absl::Notification notification;
514         ExecCtx exec_ctx;
515         test_->work_serializer_->Run(
516             [notification = &notification,
517              picker = std::move(picker_)]() mutable {
518               picker.reset();
519               notification->Notify();
520             },
521             DEBUG_LOCATION);
522         notification.WaitForNotification();
523       }
524 
Pick(LoadBalancingPolicy::PickArgs args)525       LoadBalancingPolicy::PickResult Pick(
526           LoadBalancingPolicy::PickArgs args) override {
527         return picker_->Pick(args);
528       }
529 
530      private:
531       LoadBalancingPolicyTest* const test_;
532       RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker_;
533     };
534 
535     // Represents an event reported by the LB policy.
536     using Event = absl::variant<StateUpdate, ReresolutionRequested>;
537 
538     // Returns a human-readable representation of an event.
EventString(const Event & event)539     static std::string EventString(const Event& event) {
540       return Match(
541           event, [](const StateUpdate& update) { return update.ToString(); },
542           [](const ReresolutionRequested& reresolution) {
543             return reresolution.ToString();
544           });
545     }
546 
QueueString()547     std::string QueueString() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(&mu_) {
548       std::vector<std::string> parts = {"Queue:"};
549       for (const Event& event : queue_) {
550         parts.push_back(EventString(event));
551       }
552       return absl::StrJoin(parts, "\n  ");
553     }
554 
CreateSubchannel(const grpc_resolved_address & address,const ChannelArgs &,const ChannelArgs & args)555     RefCountedPtr<SubchannelInterface> CreateSubchannel(
556         const grpc_resolved_address& address,
557         const ChannelArgs& /*per_address_args*/,
558         const ChannelArgs& args) override {
559       // TODO(roth): Need to use per_address_args here.
560       SubchannelKey key(
561           address, args.RemoveAllKeysWithPrefix(GRPC_ARG_NO_SUBCHANNEL_PREFIX));
562       auto it = test_->subchannel_pool_.find(key);
563       if (it == test_->subchannel_pool_.end()) {
564         auto address_uri = grpc_sockaddr_to_uri(&address);
565         GPR_ASSERT(address_uri.ok());
566         it = test_->subchannel_pool_
567                  .emplace(std::piecewise_construct, std::forward_as_tuple(key),
568                           std::forward_as_tuple(std::move(*address_uri), test_))
569                  .first;
570       }
571       return it->second.CreateSubchannel();
572     }
573 
UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker)574     void UpdateState(
575         grpc_connectivity_state state, const absl::Status& status,
576         RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker) override {
577       MutexLock lock(&mu_);
578       StateUpdate update{
579           state, status,
580           MakeRefCounted<PickerWrapper>(test_, std::move(picker))};
581       gpr_log(GPR_INFO, "enqueuing state update from LB policy: %s",
582               update.ToString().c_str());
583       queue_.push_back(std::move(update));
584     }
585 
RequestReresolution()586     void RequestReresolution() override {
587       MutexLock lock(&mu_);
588       queue_.push_back(ReresolutionRequested());
589     }
590 
GetTarget()591     absl::string_view GetTarget() override { return test_->target_; }
592 
GetAuthority()593     absl::string_view GetAuthority() override { return test_->authority_; }
594 
GetChannelCredentials()595     RefCountedPtr<grpc_channel_credentials> GetChannelCredentials() override {
596       return nullptr;
597     }
598 
GetUnsafeChannelCredentials()599     RefCountedPtr<grpc_channel_credentials> GetUnsafeChannelCredentials()
600         override {
601       return nullptr;
602     }
603 
GetEventEngine()604     grpc_event_engine::experimental::EventEngine* GetEventEngine() override {
605       return test_->fuzzing_ee_.get();
606     }
607 
GetStatsPluginGroup()608     GlobalStatsPluginRegistry::StatsPluginGroup& GetStatsPluginGroup()
609         override {
610       return test_->stats_plugin_group_;
611     }
612 
AddTraceEvent(TraceSeverity,absl::string_view)613     void AddTraceEvent(TraceSeverity, absl::string_view) override {}
614 
615     LoadBalancingPolicyTest* test_;
616 
617     Mutex mu_;
618     std::deque<Event> queue_ ABSL_GUARDED_BY(&mu_);
619   };
620 
621   // A fake MetadataInterface implementation, for use in PickArgs.
622   class FakeMetadata : public LoadBalancingPolicy::MetadataInterface {
623    public:
FakeMetadata(std::map<std::string,std::string> metadata)624     explicit FakeMetadata(std::map<std::string, std::string> metadata)
625         : metadata_(std::move(metadata)) {}
626 
metadata()627     const std::map<std::string, std::string>& metadata() const {
628       return metadata_;
629     }
630 
631    private:
Add(absl::string_view key,absl::string_view value)632     void Add(absl::string_view key, absl::string_view value) override {
633       metadata_[std::string(key)] = std::string(value);
634     }
635 
TestOnlyCopyToVector()636     std::vector<std::pair<std::string, std::string>> TestOnlyCopyToVector()
637         override {
638       return {};  // Not used.
639     }
640 
Lookup(absl::string_view key,std::string *)641     absl::optional<absl::string_view> Lookup(
642         absl::string_view key, std::string* /*buffer*/) const override {
643       auto it = metadata_.find(std::string(key));
644       if (it == metadata_.end()) return absl::nullopt;
645       return it->second;
646     }
647 
648     std::map<std::string, std::string> metadata_;
649   };
650 
651   // A fake CallState implementation, for use in PickArgs.
652   class FakeCallState : public ClientChannelLbCallState {
653    public:
FakeCallState(const CallAttributes & attributes)654     explicit FakeCallState(const CallAttributes& attributes) {
655       for (const auto& attribute : attributes) {
656         attributes_.emplace(attribute->type(), attribute);
657       }
658     }
659 
~FakeCallState()660     ~FakeCallState() override {
661       for (void* allocation : allocations_) {
662         gpr_free(allocation);
663       }
664     }
665 
666    private:
Alloc(size_t size)667     void* Alloc(size_t size) override {
668       void* allocation = gpr_malloc(size);
669       allocations_.push_back(allocation);
670       return allocation;
671     }
672 
GetCallAttribute(UniqueTypeName type)673     ServiceConfigCallData::CallAttributeInterface* GetCallAttribute(
674         UniqueTypeName type) const override {
675       auto it = attributes_.find(type);
676       if (it != attributes_.end()) {
677         return it->second;
678       }
679       return nullptr;
680     }
681 
GetCallAttemptTracer()682     ClientCallTracer::CallAttemptTracer* GetCallAttemptTracer() const override {
683       return nullptr;
684     }
685 
686     std::vector<void*> allocations_;
687     std::map<UniqueTypeName, ServiceConfigCallData::CallAttributeInterface*>
688         attributes_;
689   };
690 
691   // A fake BackendMetricAccessor implementation, for passing to
692   // SubchannelCallTrackerInterface::Finish().
693   class FakeBackendMetricAccessor
694       : public LoadBalancingPolicy::BackendMetricAccessor {
695    public:
FakeBackendMetricAccessor(absl::optional<BackendMetricData> backend_metric_data)696     explicit FakeBackendMetricAccessor(
697         absl::optional<BackendMetricData> backend_metric_data)
698         : backend_metric_data_(std::move(backend_metric_data)) {}
699 
GetBackendMetricData()700     const BackendMetricData* GetBackendMetricData() override {
701       if (backend_metric_data_.has_value()) return &*backend_metric_data_;
702       return nullptr;
703     }
704 
705    private:
706     const absl::optional<BackendMetricData> backend_metric_data_;
707   };
708 
709   explicit LoadBalancingPolicyTest(absl::string_view lb_policy_name,
710                                    ChannelArgs channel_args = ChannelArgs())
lb_policy_name_(lb_policy_name)711       : lb_policy_name_(lb_policy_name),
712         channel_args_(std::move(channel_args)) {}
713 
SetUp()714   void SetUp() override {
715     // Order is important here: Fuzzing EE needs to be created before
716     // grpc_init(), and the POSIX EE (which is used by the WorkSerializer)
717     // needs to be created after grpc_init().
718     fuzzing_ee_ =
719         std::make_shared<grpc_event_engine::experimental::FuzzingEventEngine>(
720             grpc_event_engine::experimental::FuzzingEventEngine::Options(),
721             fuzzing_event_engine::Actions());
722     grpc_init();
723     event_engine_ = grpc_event_engine::experimental::GetDefaultEventEngine();
724     work_serializer_ = std::make_shared<WorkSerializer>(event_engine_);
725     auto helper = std::make_unique<FakeHelper>(this);
726     helper_ = helper.get();
727     LoadBalancingPolicy::Args args = {work_serializer_, std::move(helper),
728                                       channel_args_};
729     lb_policy_ =
730         CoreConfiguration::Get().lb_policy_registry().CreateLoadBalancingPolicy(
731             lb_policy_name_, std::move(args));
732     GPR_ASSERT(lb_policy_ != nullptr);
733   }
734 
TearDown()735   void TearDown() override {
736     ExecCtx exec_ctx;
737     fuzzing_ee_->FuzzingDone();
738     // Make sure pickers (and transitively, subchannels) are unreffed before
739     // destroying the fixture.
740     WaitForWorkSerializerToFlush();
741     work_serializer_.reset();
742     exec_ctx.Flush();
743     // Note: Can't safely trigger this from inside the FakeHelper dtor,
744     // because if there is a picker in the queue that is holding a ref
745     // to the LB policy, that will prevent the LB policy from being
746     // destroyed, and therefore the FakeHelper will not be destroyed.
747     // (This will cause an ASAN failure, but it will not display the
748     // queued events, so the failure will be harder to diagnose.)
749     helper_->ExpectQueueEmpty();
750     lb_policy_.reset();
751     fuzzing_ee_->TickUntilIdle();
752     grpc_event_engine::experimental::WaitForSingleOwner(
753         std::move(event_engine_));
754     event_engine_.reset();
755     grpc_shutdown_blocking();
756     fuzzing_ee_.reset();
757   }
758 
lb_policy()759   LoadBalancingPolicy* lb_policy() const {
760     GPR_ASSERT(lb_policy_ != nullptr);
761     return lb_policy_.get();
762   }
763 
764   // Creates an LB policy config from json.
765   static RefCountedPtr<LoadBalancingPolicy::Config> MakeConfig(
766       const Json& json, SourceLocation location = SourceLocation()) {
767     auto status_or_config =
768         CoreConfiguration::Get().lb_policy_registry().ParseLoadBalancingConfig(
769             json);
770     EXPECT_TRUE(status_or_config.ok())
771         << status_or_config.status() << "\n"
772         << location.file() << ":" << location.line();
773     return status_or_config.value();
774   }
775 
776   // Converts an address URI into a grpc_resolved_address.
MakeAddress(absl::string_view address_uri)777   static grpc_resolved_address MakeAddress(absl::string_view address_uri) {
778     auto uri = URI::Parse(address_uri);
779     GPR_ASSERT(uri.ok());
780     grpc_resolved_address address;
781     GPR_ASSERT(grpc_parse_uri(*uri, &address));
782     return address;
783   }
784 
MakeAddressList(absl::Span<const absl::string_view> addresses)785   std::vector<grpc_resolved_address> MakeAddressList(
786       absl::Span<const absl::string_view> addresses) {
787     std::vector<grpc_resolved_address> addrs;
788     for (const absl::string_view& address : addresses) {
789       addrs.emplace_back(MakeAddress(address));
790     }
791     return addrs;
792   }
793 
794   EndpointAddresses MakeEndpointAddresses(
795       absl::Span<const absl::string_view> addresses,
796       const ChannelArgs& args = ChannelArgs()) {
797     return EndpointAddresses(MakeAddressList(addresses), args);
798   }
799 
800   // Constructs an update containing a list of endpoints.
801   LoadBalancingPolicy::UpdateArgs BuildUpdate(
802       absl::Span<const EndpointAddresses> endpoints,
803       RefCountedPtr<LoadBalancingPolicy::Config> config,
804       ChannelArgs args = ChannelArgs()) {
805     LoadBalancingPolicy::UpdateArgs update;
806     update.addresses = std::make_shared<EndpointAddressesListIterator>(
807         EndpointAddressesList(endpoints.begin(), endpoints.end()));
808     update.config = std::move(config);
809     update.args = std::move(args);
810     return update;
811   }
812 
MakeEndpointAddressesListFromAddressList(absl::Span<const absl::string_view> addresses)813   std::vector<EndpointAddresses> MakeEndpointAddressesListFromAddressList(
814       absl::Span<const absl::string_view> addresses) {
815     std::vector<EndpointAddresses> endpoints;
816     for (const absl::string_view address : addresses) {
817       endpoints.emplace_back(MakeAddress(address), ChannelArgs());
818     }
819     return endpoints;
820   }
821 
822   // Convenient overload that takes a flat address list.
823   LoadBalancingPolicy::UpdateArgs BuildUpdate(
824       absl::Span<const absl::string_view> addresses,
825       RefCountedPtr<LoadBalancingPolicy::Config> config,
826       ChannelArgs args = ChannelArgs()) {
827     return BuildUpdate(MakeEndpointAddressesListFromAddressList(addresses),
828                        std::move(config), std::move(args));
829   }
830 
831   // Applies the update on the LB policy.
ApplyUpdate(LoadBalancingPolicy::UpdateArgs update_args,LoadBalancingPolicy * lb_policy)832   absl::Status ApplyUpdate(LoadBalancingPolicy::UpdateArgs update_args,
833                            LoadBalancingPolicy* lb_policy) {
834     ExecCtx exec_ctx;
835     absl::Status status;
836     // When the LB policy gets the update, it will create new
837     // subchannels, and it will register connectivity state watchers and
838     // optionally health watchers for each one.  We don't want to return
839     // until all the initial notifications for all of those watchers
840     // have been delivered to the LB policy.
841     absl::Notification notification;
842     work_serializer_->Run(
843         [&]() {
844           status = lb_policy->UpdateLocked(std::move(update_args));
845           // UpdateLocked() enqueued the initial connectivity state
846           // notifications for the subchannels, so we add another
847           // callback to the queue to be executed after those initial
848           // state notifications have been delivered.
849           gpr_log(GPR_INFO,
850                   "Applied update, waiting for initial connectivity state "
851                   "notifications");
852           work_serializer_->Run(
853               [&]() {
854                 gpr_log(GPR_INFO,
855                         "Initial connectivity state notifications delivered; "
856                         "waiting for health notifications");
857                 // Now that the initial state notifications have been
858                 // delivered, the queue will contain the health watch
859                 // notifications for any subchannels in state READY,
860                 // so we add another callback to the queue to be
861                 // executed after those health watch notifications have
862                 // been delivered.
863                 work_serializer_->Run([&]() { notification.Notify(); },
864                                       DEBUG_LOCATION);
865               },
866               DEBUG_LOCATION);
867         },
868         DEBUG_LOCATION);
869     notification.WaitForNotification();
870     gpr_log(GPR_INFO, "health notifications delivered");
871     return status;
872   }
873 
874   // Invoke ExitIdle on the LB policy
ExitIdle()875   void ExitIdle() {
876     ExecCtx exec_ctx;
877     absl::Notification notification;
878     // Note: ExitIdle() will enqueue a bunch of connectivity state
879     // notifications on the WorkSerializer, and we want to wait until
880     // those are delivered to the LB policy.
881     work_serializer_->Run(
882         [&]() {
883           lb_policy_->ExitIdleLocked();
884           work_serializer_->Run([&]() { notification.Notify(); },
885                                 DEBUG_LOCATION);
886         },
887         DEBUG_LOCATION);
888     notification.WaitForNotification();
889   }
890 
891   void ExpectQueueEmpty(SourceLocation location = SourceLocation()) {
892     helper_->ExpectQueueEmpty(location);
893   }
894 
895   // Keeps reading state updates until continue_predicate() returns false.
896   // Returns false if the helper reports no events or if the event is
897   // not a state update; otherwise (if continue_predicate() tells us to
898   // stop) returns true.
899   bool WaitForStateUpdate(
900       std::function<bool(FakeHelper::StateUpdate update)> continue_predicate,
901       SourceLocation location = SourceLocation()) {
902     gpr_log(GPR_INFO, "==> WaitForStateUpdate()");
903     while (true) {
904       auto update = helper_->GetNextStateUpdate(location);
905       if (!update.has_value()) {
906         gpr_log(GPR_INFO, "WaitForStateUpdate() returning false");
907         return false;
908       }
909       if (!continue_predicate(std::move(*update))) {
910         gpr_log(GPR_INFO, "WaitForStateUpdate() returning true");
911         return true;
912       }
913     }
914   }
915 
916   void ExpectReresolutionRequest(SourceLocation location = SourceLocation()) {
917     ASSERT_TRUE(helper_->GetNextReresolution(location))
918         << location.file() << ":" << location.line();
919   }
920 
921   // Expects that the LB policy has reported the specified connectivity
922   // state to helper_.  Returns the picker from the state update.
923   RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> ExpectState(
924       grpc_connectivity_state expected_state,
925       absl::Status expected_status = absl::OkStatus(),
926       SourceLocation location = SourceLocation()) {
927     auto update = helper_->GetNextStateUpdate(location);
928     if (!update.has_value()) return nullptr;
929     EXPECT_EQ(update->state, expected_state)
930         << "got " << ConnectivityStateName(update->state) << ", expected "
931         << ConnectivityStateName(expected_state) << "\n"
932         << "at " << location.file() << ":" << location.line();
933     EXPECT_EQ(update->status, expected_status)
934         << update->status << "\n"
935         << location.file() << ":" << location.line();
936     EXPECT_NE(update->picker, nullptr)
937         << location.file() << ":" << location.line();
938     return std::move(update->picker);
939   }
940 
941   // Waits for the LB policy to get connected, then returns the final
942   // picker.  There can be any number of CONNECTING updates, each of
943   // which must return a picker that queues picks, followed by one
944   // update for state READY, whose picker is returned.
945   RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> WaitForConnected(
946       SourceLocation location = SourceLocation()) {
947     gpr_log(GPR_INFO, "==> WaitForConnected()");
948     RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> final_picker;
949     WaitForStateUpdate(
950         [&](FakeHelper::StateUpdate update) {
951           if (update.state == GRPC_CHANNEL_CONNECTING) {
952             EXPECT_TRUE(update.status.ok())
953                 << update.status << " at " << location.file() << ":"
954                 << location.line();
955             ExpectPickQueued(update.picker.get(), {}, location);
956             return true;  // Keep going.
957           }
958           EXPECT_EQ(update.state, GRPC_CHANNEL_READY)
959               << ConnectivityStateName(update.state) << " at "
960               << location.file() << ":" << location.line();
961           final_picker = std::move(update.picker);
962           return false;  // Stop.
963         },
964         location);
965     return final_picker;
966   }
967 
968   void ExpectTransientFailureUpdate(
969       absl::Status expected_status,
970       SourceLocation location = SourceLocation()) {
971     auto picker =
972         ExpectState(GRPC_CHANNEL_TRANSIENT_FAILURE, expected_status, location);
973     ASSERT_NE(picker, nullptr);
974     ExpectPickFail(
975         picker.get(),
976         [&](const absl::Status& status) {
977           EXPECT_EQ(status, expected_status)
978               << location.file() << ":" << location.line();
979         },
980         location);
981   }
982 
983   // Waits for the LB policy to fail a connection attempt.  There can be
984   // any number of CONNECTING updates, each of which must return a picker
985   // that queues picks, followed by one update for state TRANSIENT_FAILURE,
986   // whose status is passed to check_status() and whose picker must fail
987   // picks with a status that is passed to check_status().
988   // Returns true if the reported states match expectations.
989   bool WaitForConnectionFailed(
990       std::function<void(const absl::Status&)> check_status,
991       SourceLocation location = SourceLocation()) {
992     bool retval = false;
993     WaitForStateUpdate(
994         [&](FakeHelper::StateUpdate update) {
995           if (update.state == GRPC_CHANNEL_CONNECTING) {
996             EXPECT_TRUE(update.status.ok())
997                 << update.status << " at " << location.file() << ":"
998                 << location.line();
999             ExpectPickQueued(update.picker.get(), {}, location);
1000             return true;  // Keep going.
1001           }
1002           EXPECT_EQ(update.state, GRPC_CHANNEL_TRANSIENT_FAILURE)
1003               << ConnectivityStateName(update.state) << " at "
1004               << location.file() << ":" << location.line();
1005           check_status(update.status);
1006           ExpectPickFail(update.picker.get(), check_status, location);
1007           retval = update.state == GRPC_CHANNEL_TRANSIENT_FAILURE;
1008           return false;  // Stop.
1009         },
1010         location);
1011     return retval;
1012   }
1013 
1014   // Waits for the round_robin policy to start using an updated address list.
1015   // There can be any number of READY updates where the picker is still using
1016   // the old list followed by one READY update where the picker is using the
1017   // new list.  Returns a picker if the reported states match expectations.
1018   RefCountedPtr<LoadBalancingPolicy::SubchannelPicker>
1019   WaitForRoundRobinListChange(absl::Span<const absl::string_view> old_addresses,
1020                               absl::Span<const absl::string_view> new_addresses,
1021                               const CallAttributes& call_attributes = {},
1022                               size_t num_iterations = 3,
1023                               SourceLocation location = SourceLocation()) {
1024     gpr_log(GPR_INFO, "Waiting for expected RR addresses...");
1025     RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> retval;
1026     size_t num_picks =
1027         std::max(new_addresses.size(), old_addresses.size()) * num_iterations;
1028     WaitForStateUpdate(
1029         [&](FakeHelper::StateUpdate update) {
1030           EXPECT_EQ(update.state, GRPC_CHANNEL_READY)
1031               << location.file() << ":" << location.line();
1032           if (update.state != GRPC_CHANNEL_READY) return false;
1033           // Get enough picks to round-robin num_iterations times across all
1034           // expected addresses.
1035           auto picks = GetCompletePicks(update.picker.get(), num_picks,
1036                                         call_attributes, nullptr, location);
1037           EXPECT_TRUE(picks.has_value())
1038               << location.file() << ":" << location.line();
1039           if (!picks.has_value()) return false;
1040           gpr_log(GPR_INFO, "PICKS: %s", absl::StrJoin(*picks, " ").c_str());
1041           // If the picks still match the old list, then keep going.
1042           if (PicksAreRoundRobin(old_addresses, *picks)) return true;
1043           // Otherwise, the picks should match the new list.
1044           bool matches = PicksAreRoundRobin(new_addresses, *picks);
1045           EXPECT_TRUE(matches)
1046               << "Expected: " << absl::StrJoin(new_addresses, ", ")
1047               << "\nActual: " << absl::StrJoin(*picks, ", ") << "\nat "
1048               << location.file() << ":" << location.line();
1049           if (matches) {
1050             retval = std::move(update.picker);
1051           }
1052           return false;  // Stop.
1053         },
1054         location);
1055     gpr_log(GPR_INFO, "done waiting for expected RR addresses");
1056     return retval;
1057   }
1058 
1059   // Expects a state update for the specified state and status, and then
1060   // expects the resulting picker to queue picks.
1061   bool ExpectStateAndQueuingPicker(
1062       grpc_connectivity_state expected_state,
1063       absl::Status expected_status = absl::OkStatus(),
1064       SourceLocation location = SourceLocation()) {
1065     auto picker = ExpectState(expected_state, expected_status, location);
1066     return ExpectPickQueued(picker.get(), {}, location);
1067   }
1068 
1069   // Convenient frontend to ExpectStateAndQueuingPicker() for CONNECTING.
1070   bool ExpectConnectingUpdate(SourceLocation location = SourceLocation()) {
1071     return ExpectStateAndQueuingPicker(GRPC_CHANNEL_CONNECTING,
1072                                        absl::OkStatus(), location);
1073   }
1074 
1075   static std::unique_ptr<LoadBalancingPolicy::MetadataInterface> MakeMetadata(
1076       std::map<std::string, std::string> init = {}) {
1077     return std::make_unique<FakeMetadata>(init);
1078   }
1079 
1080   // Does a pick and returns the result.
1081   LoadBalancingPolicy::PickResult DoPick(
1082       LoadBalancingPolicy::SubchannelPicker* picker,
1083       const CallAttributes& call_attributes = {}) {
1084     ExecCtx exec_ctx;
1085     FakeMetadata metadata({});
1086     FakeCallState call_state(call_attributes);
1087     return picker->Pick({"/service/method", &metadata, &call_state});
1088   }
1089 
1090   // Requests a pick on picker and expects a Queue result.
1091   bool ExpectPickQueued(LoadBalancingPolicy::SubchannelPicker* picker,
1092                         const CallAttributes call_attributes = {},
1093                         SourceLocation location = SourceLocation()) {
1094     EXPECT_NE(picker, nullptr) << location.file() << ":" << location.line();
1095     if (picker == nullptr) return false;
1096     auto pick_result = DoPick(picker, call_attributes);
1097     EXPECT_TRUE(absl::holds_alternative<LoadBalancingPolicy::PickResult::Queue>(
1098         pick_result.result))
1099         << PickResultString(pick_result) << "\nat " << location.file() << ":"
1100         << location.line();
1101     return absl::holds_alternative<LoadBalancingPolicy::PickResult::Queue>(
1102         pick_result.result);
1103   }
1104 
1105   // Requests a pick on picker and expects a Complete result.
1106   // The address of the resulting subchannel is returned, or nullopt if
1107   // the result was something other than Complete.
1108   // If the complete pick includes a SubchannelCallTrackerInterface, then if
1109   // subchannel_call_tracker is non-null, it will be set to point to the
1110   // call tracker; otherwise, the call tracker will be invoked
1111   // automatically to represent a complete call with no backend metric data.
1112   absl::optional<std::string> ExpectPickComplete(
1113       LoadBalancingPolicy::SubchannelPicker* picker,
1114       const CallAttributes& call_attributes = {},
1115       std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface>*
1116           subchannel_call_tracker = nullptr,
1117       SubchannelState::FakeSubchannel** picked_subchannel = nullptr,
1118       SourceLocation location = SourceLocation()) {
1119     EXPECT_NE(picker, nullptr);
1120     if (picker == nullptr) {
1121       return absl::nullopt;
1122     }
1123     auto pick_result = DoPick(picker, call_attributes);
1124     auto* complete = absl::get_if<LoadBalancingPolicy::PickResult::Complete>(
1125         &pick_result.result);
1126     EXPECT_NE(complete, nullptr) << PickResultString(pick_result) << " at "
1127                                  << location.file() << ":" << location.line();
1128     if (complete == nullptr) return absl::nullopt;
1129     auto* subchannel = static_cast<SubchannelState::FakeSubchannel*>(
1130         complete->subchannel.get());
1131     if (picked_subchannel != nullptr) *picked_subchannel = subchannel;
1132     std::string address = subchannel->state()->address();
1133     if (complete->subchannel_call_tracker != nullptr) {
1134       if (subchannel_call_tracker != nullptr) {
1135         *subchannel_call_tracker = std::move(complete->subchannel_call_tracker);
1136       } else {
1137         ReportCompletionToCallTracker(
1138             std::move(complete->subchannel_call_tracker), address);
1139       }
1140     }
1141     return address;
1142   }
1143 
1144   void ReportCompletionToCallTracker(
1145       std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface>
1146           subchannel_call_tracker,
1147       absl::string_view address, absl::Status status = absl::OkStatus()) {
1148     subchannel_call_tracker->Start();
1149     FakeMetadata metadata({});
1150     FakeBackendMetricAccessor backend_metric_accessor({});
1151     LoadBalancingPolicy::SubchannelCallTrackerInterface::FinishArgs args = {
1152         address, status, &metadata, &backend_metric_accessor};
1153     subchannel_call_tracker->Finish(args);
1154   }
1155 
1156   // Gets num_picks complete picks from picker and returns the resulting
1157   // list of addresses, or nullopt if a non-complete pick was returned.
1158   absl::optional<std::vector<std::string>> GetCompletePicks(
1159       LoadBalancingPolicy::SubchannelPicker* picker, size_t num_picks,
1160       const CallAttributes& call_attributes = {},
1161       std::vector<
1162           std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface>>*
1163           subchannel_call_trackers = nullptr,
1164       SourceLocation location = SourceLocation()) {
1165     EXPECT_NE(picker, nullptr);
1166     if (picker == nullptr) {
1167       return absl::nullopt;
1168     }
1169     std::vector<std::string> results;
1170     for (size_t i = 0; i < num_picks; ++i) {
1171       std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface>
1172           subchannel_call_tracker;
1173       auto address = ExpectPickComplete(picker, call_attributes,
1174                                         subchannel_call_trackers == nullptr
1175                                             ? nullptr
1176                                             : &subchannel_call_tracker,
1177                                         nullptr, location);
1178       if (!address.has_value()) return absl::nullopt;
1179       results.emplace_back(std::move(*address));
1180       if (subchannel_call_trackers != nullptr) {
1181         subchannel_call_trackers->emplace_back(
1182             std::move(subchannel_call_tracker));
1183       }
1184     }
1185     return results;
1186   }
1187 
1188   // Returns true if the list of actual pick result addresses matches the
1189   // list of expected addresses for round_robin.  Note that the actual
1190   // addresses may start anywhere in the list of expected addresses but
1191   // must then continue in round-robin fashion, with wrap-around.
PicksAreRoundRobin(absl::Span<const absl::string_view> expected,absl::Span<const std::string> actual)1192   bool PicksAreRoundRobin(absl::Span<const absl::string_view> expected,
1193                           absl::Span<const std::string> actual) {
1194     absl::optional<size_t> expected_index;
1195     for (const auto& address : actual) {
1196       auto it = std::find(expected.begin(), expected.end(), address);
1197       if (it == expected.end()) return false;
1198       size_t index = it - expected.begin();
1199       if (expected_index.has_value() && index != *expected_index) return false;
1200       expected_index = (index + 1) % expected.size();
1201     }
1202     return true;
1203   }
1204 
1205   // Checks that the picker has round-robin behavior over the specified
1206   // set of addresses.
1207   void ExpectRoundRobinPicks(LoadBalancingPolicy::SubchannelPicker* picker,
1208                              absl::Span<const absl::string_view> addresses,
1209                              const CallAttributes& call_attributes = {},
1210                              size_t num_iterations = 3,
1211                              SourceLocation location = SourceLocation()) {
1212     auto picks = GetCompletePicks(picker, num_iterations * addresses.size(),
1213                                   call_attributes, nullptr, location);
1214     ASSERT_TRUE(picks.has_value()) << location.file() << ":" << location.line();
1215     EXPECT_TRUE(PicksAreRoundRobin(addresses, *picks))
1216         << "  Actual: " << absl::StrJoin(*picks, ", ")
1217         << "\n  Expected: " << absl::StrJoin(addresses, ", ") << "\n"
1218         << location.file() << ":" << location.line();
1219   }
1220 
1221   // Expect startup with RR with a set of addresses.
1222   RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> ExpectRoundRobinStartup(
1223       absl::Span<const EndpointAddresses> endpoints,
1224       SourceLocation location = SourceLocation()) {
1225     GPR_ASSERT(!endpoints.empty());
1226     // There should be a subchannel for every address.
1227     // We will wind up connecting to the first address for every endpoint.
1228     std::vector<std::vector<SubchannelState*>> endpoint_subchannels;
1229     endpoint_subchannels.reserve(endpoints.size());
1230     std::vector<std::string> chosen_addresses_storage;
1231     chosen_addresses_storage.reserve(endpoints.size());
1232     std::vector<absl::string_view> chosen_addresses;
1233     chosen_addresses.reserve(endpoints.size());
1234     for (const EndpointAddresses& endpoint : endpoints) {
1235       endpoint_subchannels.emplace_back();
1236       endpoint_subchannels.back().reserve(endpoint.addresses().size());
1237       for (size_t i = 0; i < endpoint.addresses().size(); ++i) {
1238         const grpc_resolved_address& address = endpoint.addresses()[i];
1239         std::string address_str = grpc_sockaddr_to_uri(&address).value();
1240         auto* subchannel = FindSubchannel(address_str);
1241         EXPECT_NE(subchannel, nullptr)
1242             << address_str << "\n"
1243             << location.file() << ":" << location.line();
1244         if (subchannel == nullptr) return nullptr;
1245         endpoint_subchannels.back().push_back(subchannel);
1246         if (i == 0) {
1247           chosen_addresses_storage.emplace_back(std::move(address_str));
1248           chosen_addresses.emplace_back(chosen_addresses_storage.back());
1249         }
1250       }
1251     }
1252     // We should request a connection to the first address of each endpoint,
1253     // and not to any of the subsequent addresses.
1254     for (const auto& subchannels : endpoint_subchannels) {
1255       EXPECT_TRUE(subchannels[0]->ConnectionRequested())
1256           << location.file() << ":" << location.line();
1257       for (size_t i = 1; i < subchannels.size(); ++i) {
1258         EXPECT_FALSE(subchannels[i]->ConnectionRequested())
1259             << "i=" << i << "\n"
1260             << location.file() << ":" << location.line();
1261       }
1262     }
1263     // The subchannels that we've asked to connect should report
1264     // CONNECTING state.
1265     for (size_t i = 0; i < endpoint_subchannels.size(); ++i) {
1266       endpoint_subchannels[i][0]->SetConnectivityState(GRPC_CHANNEL_CONNECTING);
1267       if (i == 0) ExpectConnectingUpdate(location);
1268     }
1269     // The connection attempts should succeed.
1270     RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker;
1271     for (size_t i = 0; i < endpoint_subchannels.size(); ++i) {
1272       endpoint_subchannels[i][0]->SetConnectivityState(GRPC_CHANNEL_READY);
1273       if (i == 0) {
1274         // When the first subchannel becomes READY, accept any number of
1275         // CONNECTING updates with a picker that queues followed by a READY
1276         // update with a picker that repeatedly returns only the first address.
1277         picker = WaitForConnected(location);
1278         ExpectRoundRobinPicks(picker.get(), {chosen_addresses[0]}, {}, 3,
1279                               location);
1280       } else {
1281         // When each subsequent subchannel becomes READY, we accept any number
1282         // of READY updates where the picker returns only the previously
1283         // connected subchannel(s) followed by a READY update where the picker
1284         // returns the previously connected subchannel(s) *and* the newly
1285         // connected subchannel.
1286         picker = WaitForRoundRobinListChange(
1287             absl::MakeSpan(chosen_addresses).subspan(0, i),
1288             absl::MakeSpan(chosen_addresses).subspan(0, i + 1), {}, 3,
1289             location);
1290       }
1291     }
1292     return picker;
1293   }
1294 
1295   // A convenient override that takes a flat list of addresses, one per
1296   // endpoint.
1297   RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> ExpectRoundRobinStartup(
1298       absl::Span<const absl::string_view> addresses,
1299       SourceLocation location = SourceLocation()) {
1300     return ExpectRoundRobinStartup(
1301         MakeEndpointAddressesListFromAddressList(addresses), location);
1302   }
1303 
1304   // Expects zero or more picker updates, each of which returns
1305   // round-robin picks for the specified set of addresses.
1306   RefCountedPtr<LoadBalancingPolicy::SubchannelPicker>
1307   DrainRoundRobinPickerUpdates(absl::Span<const absl::string_view> addresses,
1308                                SourceLocation location = SourceLocation()) {
1309     gpr_log(GPR_INFO, "Draining RR picker updates...");
1310     RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker;
1311     while (!helper_->QueueEmpty()) {
1312       auto update = helper_->GetNextStateUpdate(location);
1313       EXPECT_TRUE(update.has_value())
1314           << location.file() << ":" << location.line();
1315       if (!update.has_value()) return nullptr;
1316       EXPECT_EQ(update->state, GRPC_CHANNEL_READY)
1317           << location.file() << ":" << location.line();
1318       if (update->state != GRPC_CHANNEL_READY) return nullptr;
1319       ExpectRoundRobinPicks(update->picker.get(), addresses,
1320                             /*call_attributes=*/{}, /*num_iterations=*/3,
1321                             location);
1322       picker = std::move(update->picker);
1323     }
1324     gpr_log(GPR_INFO, "Done draining RR picker updates");
1325     return picker;
1326   }
1327 
1328   // Expects zero or more CONNECTING updates.
1329   void DrainConnectingUpdates(SourceLocation location = SourceLocation()) {
1330     gpr_log(GPR_INFO, "Draining CONNECTING updates...");
1331     while (!helper_->QueueEmpty()) {
1332       ASSERT_TRUE(ExpectConnectingUpdate(location));
1333     }
1334     gpr_log(GPR_INFO, "Done draining CONNECTING updates");
1335   }
1336 
1337   // Triggers a connection failure for the current address for an
1338   // endpoint and expects a reconnection to the specified new address.
1339   void ExpectEndpointAddressChange(
1340       absl::Span<const absl::string_view> addresses, size_t current_index,
1341       size_t new_index, absl::AnyInvocable<void()> expect_after_disconnect,
1342       SourceLocation location = SourceLocation()) {
1343     gpr_log(GPR_INFO,
1344             "Expecting endpoint address change: addresses={%s}, "
1345             "current_index=%" PRIuPTR ", new_index=%" PRIuPTR,
1346             absl::StrJoin(addresses, ", ").c_str(), current_index, new_index);
1347     ASSERT_LT(current_index, addresses.size());
1348     ASSERT_LT(new_index, addresses.size());
1349     // Find all subchannels.
1350     std::vector<SubchannelState*> subchannels;
1351     subchannels.reserve(addresses.size());
1352     for (absl::string_view address : addresses) {
1353       SubchannelState* subchannel = FindSubchannel(address);
1354       ASSERT_NE(subchannel, nullptr)
1355           << "can't find subchannel for " << address << "\n"
1356           << location.file() << ":" << location.line();
1357       subchannels.push_back(subchannel);
1358     }
1359     // Cause current_address to become disconnected.
1360     subchannels[current_index]->SetConnectivityState(GRPC_CHANNEL_IDLE);
1361     ExpectReresolutionRequest(location);
1362     if (expect_after_disconnect != nullptr) expect_after_disconnect();
1363     // Attempt each address in the list until we hit the desired new address.
1364     for (size_t i = 0; i < subchannels.size(); ++i) {
1365       // A connection should be requested on the subchannel for this
1366       // index, and none of the others.
1367       for (size_t j = 0; j < addresses.size(); ++j) {
1368         EXPECT_EQ(subchannels[j]->ConnectionRequested(), j == i)
1369             << location.file() << ":" << location.line();
1370       }
1371       // Subchannel will report CONNECTING.
1372       SubchannelState* subchannel = subchannels[i];
1373       subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING);
1374       // If this is the one we want to stick with, it will report READY.
1375       if (i == new_index) {
1376         subchannel->SetConnectivityState(GRPC_CHANNEL_READY);
1377         break;
1378       }
1379       // Otherwise, report TF.
1380       subchannel->SetConnectivityState(
1381           GRPC_CHANNEL_TRANSIENT_FAILURE,
1382           absl::UnavailableError("connection failed"));
1383       // Report IDLE to leave it in the expected state in case the test
1384       // interacts with it again.
1385       subchannel->SetConnectivityState(GRPC_CHANNEL_IDLE);
1386     }
1387     gpr_log(GPR_INFO, "Done with endpoint address change");
1388   }
1389 
  // Requests a pick on picker and expects a Fail result.
  // The failing status is passed to check_status.
1392   void ExpectPickFail(LoadBalancingPolicy::SubchannelPicker* picker,
1393                       std::function<void(const absl::Status&)> check_status,
1394                       SourceLocation location = SourceLocation()) {
1395     auto pick_result = DoPick(picker);
1396     auto* fail = absl::get_if<LoadBalancingPolicy::PickResult::Fail>(
1397         &pick_result.result);
1398     ASSERT_NE(fail, nullptr) << PickResultString(pick_result) << " at "
1399                              << location.file() << ":" << location.line();
1400     check_status(fail->status);
1401   }
1402 
1403   // Returns a human-readable string for a pick result.
PickResultString(const LoadBalancingPolicy::PickResult & result)1404   static std::string PickResultString(
1405       const LoadBalancingPolicy::PickResult& result) {
1406     return Match(
1407         result.result,
1408         [](const LoadBalancingPolicy::PickResult::Complete& complete) {
1409           auto* subchannel = static_cast<SubchannelState::FakeSubchannel*>(
1410               complete.subchannel.get());
1411           return absl::StrFormat(
1412               "COMPLETE{subchannel=%s, subchannel_call_tracker=%p}",
1413               subchannel->state()->address(),
1414               complete.subchannel_call_tracker.get());
1415         },
1416         [](const LoadBalancingPolicy::PickResult::Queue&) -> std::string {
1417           return "QUEUE{}";
1418         },
1419         [](const LoadBalancingPolicy::PickResult::Fail& fail) -> std::string {
1420           return absl::StrFormat("FAIL{%s}", fail.status.ToString());
1421         },
1422         [](const LoadBalancingPolicy::PickResult::Drop& drop) -> std::string {
1423           return absl::StrFormat("FAIL{%s}", drop.status.ToString());
1424         });
1425   }
1426 
1427   // Returns the entry in the subchannel pool, or null if not present.
1428   SubchannelState* FindSubchannel(absl::string_view address,
1429                                   const ChannelArgs& args = ChannelArgs()) {
1430     SubchannelKey key(MakeAddress(address), args);
1431     auto it = subchannel_pool_.find(key);
1432     if (it == subchannel_pool_.end()) return nullptr;
1433     return &it->second;
1434   }
1435 
1436   // Creates and returns an entry in the subchannel pool.
1437   // This can be used in cases where we want to test that a subchannel
1438   // already exists when the LB policy creates it (e.g., due to it being
1439   // created by another channel and shared via the global subchannel
1440   // pool, or by being created by another LB policy in this channel).
1441   SubchannelState* CreateSubchannel(absl::string_view address,
1442                                     const ChannelArgs& args = ChannelArgs()) {
1443     SubchannelKey key(MakeAddress(address), args);
1444     auto it = subchannel_pool_
1445                   .emplace(std::piecewise_construct, std::forward_as_tuple(key),
1446                            std::forward_as_tuple(address, this))
1447                   .first;
1448     return &it->second;
1449   }
1450 
WaitForWorkSerializerToFlush()1451   void WaitForWorkSerializerToFlush() {
1452     ExecCtx exec_ctx;
1453     gpr_log(GPR_INFO, "waiting for WorkSerializer to flush...");
1454     absl::Notification notification;
1455     work_serializer_->Run([&]() { notification.Notify(); }, DEBUG_LOCATION);
1456     notification.WaitForNotification();
1457     gpr_log(GPR_INFO, "WorkSerializer flush complete");
1458   }
1459 
IncrementTimeBy(Duration duration)1460   void IncrementTimeBy(Duration duration) {
1461     ExecCtx exec_ctx;
1462     gpr_log(GPR_INFO, "Incrementing time by %s...",
1463             duration.ToString().c_str());
1464     fuzzing_ee_->TickForDuration(duration);
1465     gpr_log(GPR_INFO, "Done incrementing time");
1466     // Flush WorkSerializer, in case the timer callback enqueued anything.
1467     WaitForWorkSerializerToFlush();
1468   }
1469 
1470   void SetExpectedTimerDuration(
1471       absl::optional<grpc_event_engine::experimental::EventEngine::Duration>
1472           duration,
1473       SourceLocation location = SourceLocation()) {
1474     if (duration.has_value()) {
1475       fuzzing_ee_->SetRunAfterDurationCallback(
1476           [expected = *duration, location = location](
1477               grpc_event_engine::experimental::EventEngine::Duration duration) {
1478             EXPECT_EQ(duration, expected)
1479                 << "Expected: " << expected.count()
1480                 << "ns\n  Actual: " << duration.count() << "ns\n"
1481                 << location.file() << ":" << location.line();
1482           });
1483     } else {
1484       fuzzing_ee_->SetRunAfterDurationCallback(nullptr);
1485     }
1486   }
1487 
  // Fuzzing event engine.  IncrementTimeBy() advances it via
  // TickForDuration(), and SetExpectedTimerDuration() installs its
  // run-after-duration callback.
  std::shared_ptr<grpc_event_engine::experimental::FuzzingEventEngine>
      fuzzing_ee_;
  // TODO(ctiller): this is a normal event engine, yet it gets its time measure
  // from fuzzing_ee_ -- results are likely to be a little funky, but seem to do
  // well enough for the tests we have today.
  // We should transition everything here to just use fuzzing_ee_, but that
  // needs some thought on how to Tick() at appropriate times, as there are
  // Notification objects buried everywhere in this code, and
  // WaitForNotification is deeply incompatible with a single threaded event
  // engine that doesn't run callbacks until its public Tick method is called.
  std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_;
  // Serializer on which WaitForWorkSerializerToFlush() schedules its marker
  // callback.
  std::shared_ptr<WorkSerializer> work_serializer_;
  // Source of the queued state updates consumed by the Drain*() helpers.
  // NOTE(review): raw pointer, null until assigned; ownership is not visible
  // in this part of the file -- confirm what keeps the helper alive.
  FakeHelper* helper_ = nullptr;
  // Fake subchannel pool, keyed on address plus channel args.  Looked up by
  // FindSubchannel() and populated by CreateSubchannel().
  std::map<SubchannelKey, SubchannelState> subchannel_pool_;
  // NOTE(review): presumably the LB policy under test -- confirm.
  OrphanablePtr<LoadBalancingPolicy> lb_policy_;
  // NOTE(review): non-owning view; the string it refers to must outlive this
  // object -- confirm against the code that initializes it.
  const absl::string_view lb_policy_name_;
  const ChannelArgs channel_args_;
  GlobalStatsPluginRegistry::StatsPluginGroup stats_plugin_group_;
  // Default target and authority strings.
  std::string target_ = "dns:server.example.com";
  std::string authority_ = "server.example.com";
1508 };
1509 
1510 }  // namespace testing
1511 }  // namespace grpc_core
1512 
1513 #endif  // GRPC_TEST_CORE_CLIENT_CHANNEL_LB_POLICY_LB_POLICY_TEST_LIB_H
1514