1 // 2 // Copyright 2022 gRPC authors. 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 17 #ifndef GRPC_TEST_CORE_CLIENT_CHANNEL_LB_POLICY_LB_POLICY_TEST_LIB_H 18 #define GRPC_TEST_CORE_CLIENT_CHANNEL_LB_POLICY_LB_POLICY_TEST_LIB_H 19 20 #include <inttypes.h> 21 #include <stddef.h> 22 23 #include <algorithm> 24 #include <chrono> 25 #include <deque> 26 #include <functional> 27 #include <map> 28 #include <memory> 29 #include <set> 30 #include <string> 31 #include <tuple> 32 #include <type_traits> 33 #include <utility> 34 #include <vector> 35 36 #include "absl/base/thread_annotations.h" 37 #include "absl/functional/any_invocable.h" 38 #include "absl/status/status.h" 39 #include "absl/status/statusor.h" 40 #include "absl/strings/str_format.h" 41 #include "absl/strings/str_join.h" 42 #include "absl/strings/string_view.h" 43 #include "absl/synchronization/notification.h" 44 #include "absl/types/optional.h" 45 #include "absl/types/span.h" 46 #include "absl/types/variant.h" 47 #include "gmock/gmock.h" 48 #include "gtest/gtest.h" 49 50 #include <grpc/event_engine/event_engine.h> 51 #include <grpc/grpc.h> 52 #include <grpc/support/alloc.h> 53 #include <grpc/support/log.h> 54 #include <grpc/support/port_platform.h> 55 56 #include "src/core/client_channel/client_channel_internal.h" 57 #include "src/core/client_channel/subchannel_interface_internal.h" 58 #include "src/core/client_channel/subchannel_pool_interface.h" 59 #include 
"src/core/lib/address_utils/parse_address.h" 60 #include "src/core/lib/address_utils/sockaddr_utils.h" 61 #include "src/core/lib/channel/channel_args.h" 62 #include "src/core/lib/config/core_configuration.h" 63 #include "src/core/lib/event_engine/default_event_engine.h" 64 #include "src/core/lib/gprpp/debug_location.h" 65 #include "src/core/lib/gprpp/match.h" 66 #include "src/core/lib/gprpp/orphanable.h" 67 #include "src/core/lib/gprpp/ref_counted_ptr.h" 68 #include "src/core/lib/gprpp/sync.h" 69 #include "src/core/lib/gprpp/time.h" 70 #include "src/core/lib/gprpp/unique_type_name.h" 71 #include "src/core/lib/gprpp/work_serializer.h" 72 #include "src/core/lib/iomgr/exec_ctx.h" 73 #include "src/core/lib/iomgr/resolved_address.h" 74 #include "src/core/lib/json/json.h" 75 #include "src/core/lib/security/credentials/credentials.h" 76 #include "src/core/lib/transport/connectivity_state.h" 77 #include "src/core/lib/uri/uri_parser.h" 78 #include "src/core/load_balancing/backend_metric_data.h" 79 #include "src/core/load_balancing/health_check_client_internal.h" 80 #include "src/core/load_balancing/lb_policy.h" 81 #include "src/core/load_balancing/lb_policy_registry.h" 82 #include "src/core/load_balancing/oob_backend_metric.h" 83 #include "src/core/load_balancing/oob_backend_metric_internal.h" 84 #include "src/core/load_balancing/subchannel_interface.h" 85 #include "src/core/resolver/endpoint_addresses.h" 86 #include "src/core/service_config/service_config_call_data.h" 87 #include "test/core/event_engine/event_engine_test_utils.h" 88 #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.h" 89 #include "test/core/event_engine/fuzzing_event_engine/fuzzing_event_engine.pb.h" 90 91 namespace grpc_core { 92 namespace testing { 93 94 class LoadBalancingPolicyTest : public ::testing::Test { 95 protected: 96 using CallAttributes = 97 std::vector<ServiceConfigCallData::CallAttributeInterface*>; 98 99 // Channel-level subchannel state for a specific address and 
channel args. 100 // This is analogous to the real subchannel in the ClientChannel code. 101 class SubchannelState { 102 public: 103 // A fake SubchannelInterface object, to be returned to the LB 104 // policy when it calls the helper's CreateSubchannel() method. 105 // There may be multiple FakeSubchannel objects associated with a 106 // given SubchannelState object. 107 class FakeSubchannel : public SubchannelInterface { 108 public: FakeSubchannel(SubchannelState * state)109 explicit FakeSubchannel(SubchannelState* state) : state_(state) {} 110 ~FakeSubchannel()111 ~FakeSubchannel() override { 112 if (orca_watcher_ != nullptr) { 113 MutexLock lock(&state_->backend_metric_watcher_mu_); 114 state_->orca_watchers_.erase(orca_watcher_.get()); 115 } 116 for (const auto& p : watcher_map_) { 117 state_->state_tracker_.RemoveWatcher(p.second); 118 } 119 } 120 state()121 SubchannelState* state() const { return state_; } 122 123 private: 124 // Converts between 125 // SubchannelInterface::ConnectivityStateWatcherInterface and 126 // ConnectivityStateWatcherInterface. 127 // 128 // We support both unique_ptr<> and shared_ptr<>, since raw 129 // connectivity watches use the latter but health watches use the 130 // former. 131 // TODO(roth): Clean this up. 
132 class WatcherWrapper : public AsyncConnectivityStateWatcherInterface { 133 public: WatcherWrapper(std::shared_ptr<WorkSerializer> work_serializer,std::unique_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> watcher)134 WatcherWrapper( 135 std::shared_ptr<WorkSerializer> work_serializer, 136 std::unique_ptr< 137 SubchannelInterface::ConnectivityStateWatcherInterface> 138 watcher) 139 : AsyncConnectivityStateWatcherInterface( 140 std::move(work_serializer)), 141 watcher_(std::move(watcher)) {} 142 WatcherWrapper(std::shared_ptr<WorkSerializer> work_serializer,std::shared_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> watcher)143 WatcherWrapper( 144 std::shared_ptr<WorkSerializer> work_serializer, 145 std::shared_ptr< 146 SubchannelInterface::ConnectivityStateWatcherInterface> 147 watcher) 148 : AsyncConnectivityStateWatcherInterface( 149 std::move(work_serializer)), 150 watcher_(std::move(watcher)) {} 151 OnConnectivityStateChange(grpc_connectivity_state new_state,const absl::Status & status)152 void OnConnectivityStateChange(grpc_connectivity_state new_state, 153 const absl::Status& status) override { 154 gpr_log(GPR_INFO, "notifying watcher: state=%s status=%s", 155 ConnectivityStateName(new_state), status.ToString().c_str()); 156 watcher_->OnConnectivityStateChange(new_state, status); 157 } 158 159 private: 160 std::shared_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> 161 watcher_; 162 }; 163 WatchConnectivityState(std::unique_ptr<SubchannelInterface::ConnectivityStateWatcherInterface> watcher)164 void WatchConnectivityState( 165 std::unique_ptr< 166 SubchannelInterface::ConnectivityStateWatcherInterface> 167 watcher) override 168 ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) { 169 auto* watcher_ptr = watcher.get(); 170 auto watcher_wrapper = MakeOrphanable<WatcherWrapper>( 171 state_->work_serializer(), std::move(watcher)); 172 watcher_map_[watcher_ptr] = watcher_wrapper.get(); 173 
state_->state_tracker_.AddWatcher(GRPC_CHANNEL_SHUTDOWN, 174 std::move(watcher_wrapper)); 175 } 176 CancelConnectivityStateWatch(ConnectivityStateWatcherInterface * watcher)177 void CancelConnectivityStateWatch( 178 ConnectivityStateWatcherInterface* watcher) override 179 ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) { 180 auto it = watcher_map_.find(watcher); 181 if (it == watcher_map_.end()) return; 182 state_->state_tracker_.RemoveWatcher(it->second); 183 watcher_map_.erase(it); 184 } 185 RequestConnection()186 void RequestConnection() override { 187 MutexLock lock(&state_->requested_connection_mu_); 188 state_->requested_connection_ = true; 189 } 190 AddDataWatcher(std::unique_ptr<DataWatcherInterface> watcher)191 void AddDataWatcher( 192 std::unique_ptr<DataWatcherInterface> watcher) override 193 ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) { 194 MutexLock lock(&state_->backend_metric_watcher_mu_); 195 auto* w = 196 static_cast<InternalSubchannelDataWatcherInterface*>(watcher.get()); 197 if (w->type() == OrcaProducer::Type()) { 198 GPR_ASSERT(orca_watcher_ == nullptr); 199 orca_watcher_.reset(static_cast<OrcaWatcher*>(watcher.release())); 200 state_->orca_watchers_.insert(orca_watcher_.get()); 201 } else if (w->type() == HealthProducer::Type()) { 202 // TODO(roth): Support health checking in test framework. 203 // For now, we just hard-code this to the raw connectivity state. 
204 GPR_ASSERT(health_watcher_ == nullptr); 205 GPR_ASSERT(health_watcher_wrapper_ == nullptr); 206 health_watcher_.reset(static_cast<HealthWatcher*>(watcher.release())); 207 auto connectivity_watcher = health_watcher_->TakeWatcher(); 208 auto* connectivity_watcher_ptr = connectivity_watcher.get(); 209 auto watcher_wrapper = MakeOrphanable<WatcherWrapper>( 210 state_->work_serializer(), std::move(connectivity_watcher)); 211 health_watcher_wrapper_ = watcher_wrapper.get(); 212 state_->state_tracker_.AddWatcher(GRPC_CHANNEL_SHUTDOWN, 213 std::move(watcher_wrapper)); 214 gpr_log(GPR_INFO, 215 "AddDataWatcher(): added HealthWatch=%p " 216 "connectivity_watcher=%p watcher_wrapper=%p", 217 health_watcher_.get(), connectivity_watcher_ptr, 218 health_watcher_wrapper_); 219 } 220 } 221 CancelDataWatcher(DataWatcherInterface * watcher)222 void CancelDataWatcher(DataWatcherInterface* watcher) override 223 ABSL_EXCLUSIVE_LOCKS_REQUIRED(*state_->test_->work_serializer_) { 224 MutexLock lock(&state_->backend_metric_watcher_mu_); 225 auto* w = static_cast<InternalSubchannelDataWatcherInterface*>(watcher); 226 if (w->type() == OrcaProducer::Type()) { 227 if (orca_watcher_.get() != static_cast<OrcaWatcher*>(watcher)) return; 228 state_->orca_watchers_.erase(orca_watcher_.get()); 229 orca_watcher_.reset(); 230 } else if (w->type() == HealthProducer::Type()) { 231 if (health_watcher_.get() != static_cast<HealthWatcher*>(watcher)) { 232 return; 233 } 234 gpr_log(GPR_INFO, 235 "CancelDataWatcher(): cancelling HealthWatch=%p " 236 "watcher_wrapper=%p", 237 health_watcher_.get(), health_watcher_wrapper_); 238 state_->state_tracker_.RemoveWatcher(health_watcher_wrapper_); 239 health_watcher_wrapper_ = nullptr; 240 health_watcher_.reset(); 241 } 242 } 243 244 // Don't need this method, so it's a no-op. 
ResetBackoff()245 void ResetBackoff() override {} 246 247 SubchannelState* state_; 248 std::map<SubchannelInterface::ConnectivityStateWatcherInterface*, 249 WatcherWrapper*> 250 watcher_map_; 251 std::unique_ptr<HealthWatcher> health_watcher_; 252 WatcherWrapper* health_watcher_wrapper_ = nullptr; 253 std::unique_ptr<OrcaWatcher> orca_watcher_; 254 }; 255 SubchannelState(absl::string_view address,LoadBalancingPolicyTest * test)256 SubchannelState(absl::string_view address, LoadBalancingPolicyTest* test) 257 : address_(address), 258 test_(test), 259 state_tracker_("LoadBalancingPolicyTest") {} 260 address()261 const std::string& address() const { return address_; } 262 263 void AssertValidConnectivityStateTransition( 264 grpc_connectivity_state from_state, grpc_connectivity_state to_state, 265 SourceLocation location = SourceLocation()) { 266 switch (from_state) { 267 case GRPC_CHANNEL_IDLE: 268 ASSERT_EQ(to_state, GRPC_CHANNEL_CONNECTING) 269 << ConnectivityStateName(from_state) << "=>" 270 << ConnectivityStateName(to_state) << "\n" 271 << location.file() << ":" << location.line(); 272 break; 273 case GRPC_CHANNEL_CONNECTING: 274 ASSERT_THAT(to_state, 275 ::testing::AnyOf(GRPC_CHANNEL_READY, 276 GRPC_CHANNEL_TRANSIENT_FAILURE)) 277 << ConnectivityStateName(from_state) << "=>" 278 << ConnectivityStateName(to_state) << "\n" 279 << location.file() << ":" << location.line(); 280 break; 281 case GRPC_CHANNEL_READY: 282 ASSERT_EQ(to_state, GRPC_CHANNEL_IDLE) 283 << ConnectivityStateName(from_state) << "=>" 284 << ConnectivityStateName(to_state) << "\n" 285 << location.file() << ":" << location.line(); 286 break; 287 case GRPC_CHANNEL_TRANSIENT_FAILURE: 288 ASSERT_EQ(to_state, GRPC_CHANNEL_IDLE) 289 << ConnectivityStateName(from_state) << "=>" 290 << ConnectivityStateName(to_state) << "\n" 291 << location.file() << ":" << location.line(); 292 break; 293 default: 294 FAIL() << ConnectivityStateName(from_state) << "=>" 295 << ConnectivityStateName(to_state) << "\n" 296 << 
location.file() << ":" << location.line(); 297 break; 298 } 299 } 300 301 // Sets the connectivity state for this subchannel. The updated state 302 // will be reported to all associated SubchannelInterface objects. 303 void SetConnectivityState(grpc_connectivity_state state, 304 const absl::Status& status = absl::OkStatus(), 305 bool validate_state_transition = true, 306 SourceLocation location = SourceLocation()) { 307 ExecCtx exec_ctx; 308 if (state == GRPC_CHANNEL_TRANSIENT_FAILURE) { 309 EXPECT_FALSE(status.ok()) 310 << "bug in test: TRANSIENT_FAILURE must have non-OK status"; 311 } else { 312 EXPECT_TRUE(status.ok()) 313 << "bug in test: " << ConnectivityStateName(state) 314 << " must have OK status: " << status; 315 } 316 // Updating the state in the state tracker will enqueue 317 // notifications to watchers on the WorkSerializer. If any 318 // subchannel reports READY, the pick_first leaf policy will then 319 // start a health watch, whose initial notification will also be 320 // scheduled on the WorkSerializer. We don't want to return until 321 // all of those notifications have been delivered. 322 absl::Notification notification; 323 test_->work_serializer_->Run( 324 [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(*test_->work_serializer_) { 325 if (validate_state_transition) { 326 AssertValidConnectivityStateTransition(state_tracker_.state(), 327 state, location); 328 } 329 gpr_log(GPR_INFO, "Setting state on tracker"); 330 state_tracker_.SetState(state, status, "set from test"); 331 // SetState() enqueued the connectivity state notifications for 332 // the subchannel, so we add another callback to the queue to be 333 // executed after that state notifications has been delivered. 
334 gpr_log(GPR_INFO, 335 "Waiting for state notifications to be delivered"); 336 test_->work_serializer_->Run( 337 [&]() { 338 gpr_log(GPR_INFO, 339 "State notifications delivered, waiting for health " 340 "notifications"); 341 // Now the connectivity state notifications has been 342 // delivered. If the state reported was READY, then the 343 // pick_first leaf policy will have started a health watch, so 344 // we add another callback to the queue to be executed after 345 // the initial health watch notification has been delivered. 346 test_->work_serializer_->Run([&]() { notification.Notify(); }, 347 DEBUG_LOCATION); 348 }, 349 DEBUG_LOCATION); 350 }, 351 DEBUG_LOCATION); 352 notification.WaitForNotification(); 353 gpr_log(GPR_INFO, "Health notifications delivered"); 354 } 355 356 // Indicates if any of the associated SubchannelInterface objects 357 // have requested a connection attempt since the last time this 358 // method was called. ConnectionRequested()359 bool ConnectionRequested() { 360 MutexLock lock(&requested_connection_mu_); 361 return std::exchange(requested_connection_, false); 362 } 363 364 // To be invoked by FakeHelper. CreateSubchannel()365 RefCountedPtr<SubchannelInterface> CreateSubchannel() { 366 return MakeRefCounted<FakeSubchannel>(this); 367 } 368 369 // Sends an OOB backend metric report to all watchers. SendOobBackendMetricReport(const BackendMetricData & backend_metrics)370 void SendOobBackendMetricReport(const BackendMetricData& backend_metrics) { 371 MutexLock lock(&backend_metric_watcher_mu_); 372 for (const auto* watcher : orca_watchers_) { 373 watcher->watcher()->OnBackendMetricReport(backend_metrics); 374 } 375 } 376 377 // Checks that all OOB watchers have the expected reporting period. 
378 void CheckOobReportingPeriod(Duration expected, 379 SourceLocation location = SourceLocation()) { 380 MutexLock lock(&backend_metric_watcher_mu_); 381 for (const auto* watcher : orca_watchers_) { 382 EXPECT_EQ(watcher->report_interval(), expected) 383 << location.file() << ":" << location.line(); 384 } 385 } 386 NumWatchers()387 size_t NumWatchers() const { 388 size_t num_watchers; 389 absl::Notification notification; 390 work_serializer()->Run( 391 [&]() ABSL_EXCLUSIVE_LOCKS_REQUIRED(*test_->work_serializer_) { 392 num_watchers = state_tracker_.NumWatchers(); 393 notification.Notify(); 394 }, 395 DEBUG_LOCATION); 396 notification.WaitForNotification(); 397 return num_watchers; 398 } 399 work_serializer()400 std::shared_ptr<WorkSerializer> work_serializer() const { 401 return test_->work_serializer_; 402 } 403 state_tracker()404 ConnectivityStateTracker& state_tracker() { return state_tracker_; } 405 406 private: 407 const std::string address_; 408 LoadBalancingPolicyTest* const test_; 409 ConnectivityStateTracker state_tracker_ 410 ABSL_GUARDED_BY(*test_->work_serializer_); 411 412 Mutex requested_connection_mu_; 413 bool requested_connection_ ABSL_GUARDED_BY(&requested_connection_mu_) = 414 false; 415 416 Mutex backend_metric_watcher_mu_; 417 std::set<OrcaWatcher*> orca_watchers_ 418 ABSL_GUARDED_BY(&backend_metric_watcher_mu_); 419 }; 420 421 // A fake helper to be passed to the LB policy. 422 class FakeHelper : public LoadBalancingPolicy::ChannelControlHelper { 423 public: 424 // Represents a state update reported by the LB policy. 425 struct StateUpdate { 426 grpc_connectivity_state state; 427 absl::Status status; 428 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker; 429 ToStringStateUpdate430 std::string ToString() const { 431 return absl::StrFormat("UPDATE{state=%s, status=%s, picker=%p}", 432 ConnectivityStateName(state), status.ToString(), 433 picker.get()); 434 } 435 }; 436 437 // Represents a re-resolution request from the LB policy. 
438 struct ReresolutionRequested { ToStringReresolutionRequested439 std::string ToString() const { return "RERESOLUTION"; } 440 }; 441 FakeHelper(LoadBalancingPolicyTest * test)442 explicit FakeHelper(LoadBalancingPolicyTest* test) : test_(test) {} 443 QueueEmpty()444 bool QueueEmpty() { 445 MutexLock lock(&mu_); 446 return queue_.empty(); 447 } 448 449 // Called at test tear-down time to ensure that we have not left any 450 // unexpected events in the queue. 451 void ExpectQueueEmpty(SourceLocation location = SourceLocation()) { 452 MutexLock lock(&mu_); 453 EXPECT_TRUE(queue_.empty()) 454 << location.file() << ":" << location.line() << "\n" 455 << QueueString(); 456 } 457 458 // Returns the next event in the queue if it is a state update. 459 // If the queue is empty or the next event is not a state update, 460 // fails the test and returns nullopt without removing anything from 461 // the queue. 462 absl::optional<StateUpdate> GetNextStateUpdate( 463 SourceLocation location = SourceLocation()) { 464 MutexLock lock(&mu_); 465 EXPECT_FALSE(queue_.empty()) << location.file() << ":" << location.line(); 466 if (queue_.empty()) return absl::nullopt; 467 Event& event = queue_.front(); 468 auto* update = absl::get_if<StateUpdate>(&event); 469 EXPECT_NE(update, nullptr) 470 << "unexpected event " << EventString(event) << " at " 471 << location.file() << ":" << location.line(); 472 if (update == nullptr) return absl::nullopt; 473 StateUpdate result = std::move(*update); 474 gpr_log(GPR_INFO, "dequeued next state update: %s", 475 result.ToString().c_str()); 476 queue_.pop_front(); 477 return std::move(result); 478 } 479 480 // Returns the next event in the queue if it is a re-resolution. 481 // If the queue is empty or the next event is not a re-resolution, 482 // fails the test and returns nullopt without removing anything 483 // from the queue. 
484 absl::optional<ReresolutionRequested> GetNextReresolution( 485 SourceLocation location = SourceLocation()) { 486 MutexLock lock(&mu_); 487 EXPECT_FALSE(queue_.empty()) << location.file() << ":" << location.line(); 488 if (queue_.empty()) return absl::nullopt; 489 Event& event = queue_.front(); 490 auto* reresolution = absl::get_if<ReresolutionRequested>(&event); 491 EXPECT_NE(reresolution, nullptr) 492 << "unexpected event " << EventString(event) << " at " 493 << location.file() << ":" << location.line(); 494 if (reresolution == nullptr) return absl::nullopt; 495 ReresolutionRequested result = *reresolution; 496 queue_.pop_front(); 497 return result; 498 } 499 500 private: 501 // A wrapper for a picker that hops into the WorkSerializer to 502 // release the ref to the picker. 503 class PickerWrapper : public LoadBalancingPolicy::SubchannelPicker { 504 public: PickerWrapper(LoadBalancingPolicyTest * test,RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker)505 PickerWrapper(LoadBalancingPolicyTest* test, 506 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker) 507 : test_(test), picker_(std::move(picker)) { 508 gpr_log(GPR_INFO, "creating wrapper %p for picker %p", this, 509 picker_.get()); 510 } 511 Orphaned()512 void Orphaned() override { 513 absl::Notification notification; 514 ExecCtx exec_ctx; 515 test_->work_serializer_->Run( 516 [notification = ¬ification, 517 picker = std::move(picker_)]() mutable { 518 picker.reset(); 519 notification->Notify(); 520 }, 521 DEBUG_LOCATION); 522 notification.WaitForNotification(); 523 } 524 Pick(LoadBalancingPolicy::PickArgs args)525 LoadBalancingPolicy::PickResult Pick( 526 LoadBalancingPolicy::PickArgs args) override { 527 return picker_->Pick(args); 528 } 529 530 private: 531 LoadBalancingPolicyTest* const test_; 532 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker_; 533 }; 534 535 // Represents an event reported by the LB policy. 
536 using Event = absl::variant<StateUpdate, ReresolutionRequested>; 537 538 // Returns a human-readable representation of an event. EventString(const Event & event)539 static std::string EventString(const Event& event) { 540 return Match( 541 event, [](const StateUpdate& update) { return update.ToString(); }, 542 [](const ReresolutionRequested& reresolution) { 543 return reresolution.ToString(); 544 }); 545 } 546 QueueString()547 std::string QueueString() const ABSL_EXCLUSIVE_LOCKS_REQUIRED(&mu_) { 548 std::vector<std::string> parts = {"Queue:"}; 549 for (const Event& event : queue_) { 550 parts.push_back(EventString(event)); 551 } 552 return absl::StrJoin(parts, "\n "); 553 } 554 CreateSubchannel(const grpc_resolved_address & address,const ChannelArgs &,const ChannelArgs & args)555 RefCountedPtr<SubchannelInterface> CreateSubchannel( 556 const grpc_resolved_address& address, 557 const ChannelArgs& /*per_address_args*/, 558 const ChannelArgs& args) override { 559 // TODO(roth): Need to use per_address_args here. 
560 SubchannelKey key( 561 address, args.RemoveAllKeysWithPrefix(GRPC_ARG_NO_SUBCHANNEL_PREFIX)); 562 auto it = test_->subchannel_pool_.find(key); 563 if (it == test_->subchannel_pool_.end()) { 564 auto address_uri = grpc_sockaddr_to_uri(&address); 565 GPR_ASSERT(address_uri.ok()); 566 it = test_->subchannel_pool_ 567 .emplace(std::piecewise_construct, std::forward_as_tuple(key), 568 std::forward_as_tuple(std::move(*address_uri), test_)) 569 .first; 570 } 571 return it->second.CreateSubchannel(); 572 } 573 UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker)574 void UpdateState( 575 grpc_connectivity_state state, const absl::Status& status, 576 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker) override { 577 MutexLock lock(&mu_); 578 StateUpdate update{ 579 state, status, 580 MakeRefCounted<PickerWrapper>(test_, std::move(picker))}; 581 gpr_log(GPR_INFO, "enqueuing state update from LB policy: %s", 582 update.ToString().c_str()); 583 queue_.push_back(std::move(update)); 584 } 585 RequestReresolution()586 void RequestReresolution() override { 587 MutexLock lock(&mu_); 588 queue_.push_back(ReresolutionRequested()); 589 } 590 GetTarget()591 absl::string_view GetTarget() override { return test_->target_; } 592 GetAuthority()593 absl::string_view GetAuthority() override { return test_->authority_; } 594 GetChannelCredentials()595 RefCountedPtr<grpc_channel_credentials> GetChannelCredentials() override { 596 return nullptr; 597 } 598 GetUnsafeChannelCredentials()599 RefCountedPtr<grpc_channel_credentials> GetUnsafeChannelCredentials() 600 override { 601 return nullptr; 602 } 603 GetEventEngine()604 grpc_event_engine::experimental::EventEngine* GetEventEngine() override { 605 return test_->fuzzing_ee_.get(); 606 } 607 GetStatsPluginGroup()608 GlobalStatsPluginRegistry::StatsPluginGroup& GetStatsPluginGroup() 609 override { 610 return test_->stats_plugin_group_; 611 } 612 
AddTraceEvent(TraceSeverity,absl::string_view)613 void AddTraceEvent(TraceSeverity, absl::string_view) override {} 614 615 LoadBalancingPolicyTest* test_; 616 617 Mutex mu_; 618 std::deque<Event> queue_ ABSL_GUARDED_BY(&mu_); 619 }; 620 621 // A fake MetadataInterface implementation, for use in PickArgs. 622 class FakeMetadata : public LoadBalancingPolicy::MetadataInterface { 623 public: FakeMetadata(std::map<std::string,std::string> metadata)624 explicit FakeMetadata(std::map<std::string, std::string> metadata) 625 : metadata_(std::move(metadata)) {} 626 metadata()627 const std::map<std::string, std::string>& metadata() const { 628 return metadata_; 629 } 630 631 private: Add(absl::string_view key,absl::string_view value)632 void Add(absl::string_view key, absl::string_view value) override { 633 metadata_[std::string(key)] = std::string(value); 634 } 635 TestOnlyCopyToVector()636 std::vector<std::pair<std::string, std::string>> TestOnlyCopyToVector() 637 override { 638 return {}; // Not used. 639 } 640 Lookup(absl::string_view key,std::string *)641 absl::optional<absl::string_view> Lookup( 642 absl::string_view key, std::string* /*buffer*/) const override { 643 auto it = metadata_.find(std::string(key)); 644 if (it == metadata_.end()) return absl::nullopt; 645 return it->second; 646 } 647 648 std::map<std::string, std::string> metadata_; 649 }; 650 651 // A fake CallState implementation, for use in PickArgs. 
652 class FakeCallState : public ClientChannelLbCallState { 653 public: FakeCallState(const CallAttributes & attributes)654 explicit FakeCallState(const CallAttributes& attributes) { 655 for (const auto& attribute : attributes) { 656 attributes_.emplace(attribute->type(), attribute); 657 } 658 } 659 ~FakeCallState()660 ~FakeCallState() override { 661 for (void* allocation : allocations_) { 662 gpr_free(allocation); 663 } 664 } 665 666 private: Alloc(size_t size)667 void* Alloc(size_t size) override { 668 void* allocation = gpr_malloc(size); 669 allocations_.push_back(allocation); 670 return allocation; 671 } 672 GetCallAttribute(UniqueTypeName type)673 ServiceConfigCallData::CallAttributeInterface* GetCallAttribute( 674 UniqueTypeName type) const override { 675 auto it = attributes_.find(type); 676 if (it != attributes_.end()) { 677 return it->second; 678 } 679 return nullptr; 680 } 681 GetCallAttemptTracer()682 ClientCallTracer::CallAttemptTracer* GetCallAttemptTracer() const override { 683 return nullptr; 684 } 685 686 std::vector<void*> allocations_; 687 std::map<UniqueTypeName, ServiceConfigCallData::CallAttributeInterface*> 688 attributes_; 689 }; 690 691 // A fake BackendMetricAccessor implementation, for passing to 692 // SubchannelCallTrackerInterface::Finish(). 
693 class FakeBackendMetricAccessor 694 : public LoadBalancingPolicy::BackendMetricAccessor { 695 public: FakeBackendMetricAccessor(absl::optional<BackendMetricData> backend_metric_data)696 explicit FakeBackendMetricAccessor( 697 absl::optional<BackendMetricData> backend_metric_data) 698 : backend_metric_data_(std::move(backend_metric_data)) {} 699 GetBackendMetricData()700 const BackendMetricData* GetBackendMetricData() override { 701 if (backend_metric_data_.has_value()) return &*backend_metric_data_; 702 return nullptr; 703 } 704 705 private: 706 const absl::optional<BackendMetricData> backend_metric_data_; 707 }; 708 709 explicit LoadBalancingPolicyTest(absl::string_view lb_policy_name, 710 ChannelArgs channel_args = ChannelArgs()) lb_policy_name_(lb_policy_name)711 : lb_policy_name_(lb_policy_name), 712 channel_args_(std::move(channel_args)) {} 713 SetUp()714 void SetUp() override { 715 // Order is important here: Fuzzing EE needs to be created before 716 // grpc_init(), and the POSIX EE (which is used by the WorkSerializer) 717 // needs to be created after grpc_init(). 
718 fuzzing_ee_ = 719 std::make_shared<grpc_event_engine::experimental::FuzzingEventEngine>( 720 grpc_event_engine::experimental::FuzzingEventEngine::Options(), 721 fuzzing_event_engine::Actions()); 722 grpc_init(); 723 event_engine_ = grpc_event_engine::experimental::GetDefaultEventEngine(); 724 work_serializer_ = std::make_shared<WorkSerializer>(event_engine_); 725 auto helper = std::make_unique<FakeHelper>(this); 726 helper_ = helper.get(); 727 LoadBalancingPolicy::Args args = {work_serializer_, std::move(helper), 728 channel_args_}; 729 lb_policy_ = 730 CoreConfiguration::Get().lb_policy_registry().CreateLoadBalancingPolicy( 731 lb_policy_name_, std::move(args)); 732 GPR_ASSERT(lb_policy_ != nullptr); 733 } 734 TearDown()735 void TearDown() override { 736 ExecCtx exec_ctx; 737 fuzzing_ee_->FuzzingDone(); 738 // Make sure pickers (and transitively, subchannels) are unreffed before 739 // destroying the fixture. 740 WaitForWorkSerializerToFlush(); 741 work_serializer_.reset(); 742 exec_ctx.Flush(); 743 // Note: Can't safely trigger this from inside the FakeHelper dtor, 744 // because if there is a picker in the queue that is holding a ref 745 // to the LB policy, that will prevent the LB policy from being 746 // destroyed, and therefore the FakeHelper will not be destroyed. 747 // (This will cause an ASAN failure, but it will not display the 748 // queued events, so the failure will be harder to diagnose.) 749 helper_->ExpectQueueEmpty(); 750 lb_policy_.reset(); 751 fuzzing_ee_->TickUntilIdle(); 752 grpc_event_engine::experimental::WaitForSingleOwner( 753 std::move(event_engine_)); 754 event_engine_.reset(); 755 grpc_shutdown_blocking(); 756 fuzzing_ee_.reset(); 757 } 758 lb_policy()759 LoadBalancingPolicy* lb_policy() const { 760 GPR_ASSERT(lb_policy_ != nullptr); 761 return lb_policy_.get(); 762 } 763 764 // Creates an LB policy config from json. 
765 static RefCountedPtr<LoadBalancingPolicy::Config> MakeConfig( 766 const Json& json, SourceLocation location = SourceLocation()) { 767 auto status_or_config = 768 CoreConfiguration::Get().lb_policy_registry().ParseLoadBalancingConfig( 769 json); 770 EXPECT_TRUE(status_or_config.ok()) 771 << status_or_config.status() << "\n" 772 << location.file() << ":" << location.line(); 773 return status_or_config.value(); 774 } 775 776 // Converts an address URI into a grpc_resolved_address. MakeAddress(absl::string_view address_uri)777 static grpc_resolved_address MakeAddress(absl::string_view address_uri) { 778 auto uri = URI::Parse(address_uri); 779 GPR_ASSERT(uri.ok()); 780 grpc_resolved_address address; 781 GPR_ASSERT(grpc_parse_uri(*uri, &address)); 782 return address; 783 } 784 MakeAddressList(absl::Span<const absl::string_view> addresses)785 std::vector<grpc_resolved_address> MakeAddressList( 786 absl::Span<const absl::string_view> addresses) { 787 std::vector<grpc_resolved_address> addrs; 788 for (const absl::string_view& address : addresses) { 789 addrs.emplace_back(MakeAddress(address)); 790 } 791 return addrs; 792 } 793 794 EndpointAddresses MakeEndpointAddresses( 795 absl::Span<const absl::string_view> addresses, 796 const ChannelArgs& args = ChannelArgs()) { 797 return EndpointAddresses(MakeAddressList(addresses), args); 798 } 799 800 // Constructs an update containing a list of endpoints. 
801 LoadBalancingPolicy::UpdateArgs BuildUpdate( 802 absl::Span<const EndpointAddresses> endpoints, 803 RefCountedPtr<LoadBalancingPolicy::Config> config, 804 ChannelArgs args = ChannelArgs()) { 805 LoadBalancingPolicy::UpdateArgs update; 806 update.addresses = std::make_shared<EndpointAddressesListIterator>( 807 EndpointAddressesList(endpoints.begin(), endpoints.end())); 808 update.config = std::move(config); 809 update.args = std::move(args); 810 return update; 811 } 812 MakeEndpointAddressesListFromAddressList(absl::Span<const absl::string_view> addresses)813 std::vector<EndpointAddresses> MakeEndpointAddressesListFromAddressList( 814 absl::Span<const absl::string_view> addresses) { 815 std::vector<EndpointAddresses> endpoints; 816 for (const absl::string_view address : addresses) { 817 endpoints.emplace_back(MakeAddress(address), ChannelArgs()); 818 } 819 return endpoints; 820 } 821 822 // Convenient overload that takes a flat address list. 823 LoadBalancingPolicy::UpdateArgs BuildUpdate( 824 absl::Span<const absl::string_view> addresses, 825 RefCountedPtr<LoadBalancingPolicy::Config> config, 826 ChannelArgs args = ChannelArgs()) { 827 return BuildUpdate(MakeEndpointAddressesListFromAddressList(addresses), 828 std::move(config), std::move(args)); 829 } 830 831 // Applies the update on the LB policy. ApplyUpdate(LoadBalancingPolicy::UpdateArgs update_args,LoadBalancingPolicy * lb_policy)832 absl::Status ApplyUpdate(LoadBalancingPolicy::UpdateArgs update_args, 833 LoadBalancingPolicy* lb_policy) { 834 ExecCtx exec_ctx; 835 absl::Status status; 836 // When the LB policy gets the update, it will create new 837 // subchannels, and it will register connectivity state watchers and 838 // optionally health watchers for each one. We don't want to return 839 // until all the initial notifications for all of those watchers 840 // have been delivered to the LB policy. 
841 absl::Notification notification; 842 work_serializer_->Run( 843 [&]() { 844 status = lb_policy->UpdateLocked(std::move(update_args)); 845 // UpdateLocked() enqueued the initial connectivity state 846 // notifications for the subchannels, so we add another 847 // callback to the queue to be executed after those initial 848 // state notifications have been delivered. 849 gpr_log(GPR_INFO, 850 "Applied update, waiting for initial connectivity state " 851 "notifications"); 852 work_serializer_->Run( 853 [&]() { 854 gpr_log(GPR_INFO, 855 "Initial connectivity state notifications delivered; " 856 "waiting for health notifications"); 857 // Now that the initial state notifications have been 858 // delivered, the queue will contain the health watch 859 // notifications for any subchannels in state READY, 860 // so we add another callback to the queue to be 861 // executed after those health watch notifications have 862 // been delivered. 863 work_serializer_->Run([&]() { notification.Notify(); }, 864 DEBUG_LOCATION); 865 }, 866 DEBUG_LOCATION); 867 }, 868 DEBUG_LOCATION); 869 notification.WaitForNotification(); 870 gpr_log(GPR_INFO, "health notifications delivered"); 871 return status; 872 } 873 874 // Invoke ExitIdle on the LB policy ExitIdle()875 void ExitIdle() { 876 ExecCtx exec_ctx; 877 absl::Notification notification; 878 // Note: ExitIdle() will enqueue a bunch of connectivity state 879 // notifications on the WorkSerializer, and we want to wait until 880 // those are delivered to the LB policy. 881 work_serializer_->Run( 882 [&]() { 883 lb_policy_->ExitIdleLocked(); 884 work_serializer_->Run([&]() { notification.Notify(); }, 885 DEBUG_LOCATION); 886 }, 887 DEBUG_LOCATION); 888 notification.WaitForNotification(); 889 } 890 891 void ExpectQueueEmpty(SourceLocation location = SourceLocation()) { 892 helper_->ExpectQueueEmpty(location); 893 } 894 895 // Keeps reading state updates until continue_predicate() returns false. 
896 // Returns false if the helper reports no events or if the event is 897 // not a state update; otherwise (if continue_predicate() tells us to 898 // stop) returns true. 899 bool WaitForStateUpdate( 900 std::function<bool(FakeHelper::StateUpdate update)> continue_predicate, 901 SourceLocation location = SourceLocation()) { 902 gpr_log(GPR_INFO, "==> WaitForStateUpdate()"); 903 while (true) { 904 auto update = helper_->GetNextStateUpdate(location); 905 if (!update.has_value()) { 906 gpr_log(GPR_INFO, "WaitForStateUpdate() returning false"); 907 return false; 908 } 909 if (!continue_predicate(std::move(*update))) { 910 gpr_log(GPR_INFO, "WaitForStateUpdate() returning true"); 911 return true; 912 } 913 } 914 } 915 916 void ExpectReresolutionRequest(SourceLocation location = SourceLocation()) { 917 ASSERT_TRUE(helper_->GetNextReresolution(location)) 918 << location.file() << ":" << location.line(); 919 } 920 921 // Expects that the LB policy has reported the specified connectivity 922 // state to helper_. Returns the picker from the state update. 923 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> ExpectState( 924 grpc_connectivity_state expected_state, 925 absl::Status expected_status = absl::OkStatus(), 926 SourceLocation location = SourceLocation()) { 927 auto update = helper_->GetNextStateUpdate(location); 928 if (!update.has_value()) return nullptr; 929 EXPECT_EQ(update->state, expected_state) 930 << "got " << ConnectivityStateName(update->state) << ", expected " 931 << ConnectivityStateName(expected_state) << "\n" 932 << "at " << location.file() << ":" << location.line(); 933 EXPECT_EQ(update->status, expected_status) 934 << update->status << "\n" 935 << location.file() << ":" << location.line(); 936 EXPECT_NE(update->picker, nullptr) 937 << location.file() << ":" << location.line(); 938 return std::move(update->picker); 939 } 940 941 // Waits for the LB policy to get connected, then returns the final 942 // picker. 
There can be any number of CONNECTING updates, each of 943 // which must return a picker that queues picks, followed by one 944 // update for state READY, whose picker is returned. 945 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> WaitForConnected( 946 SourceLocation location = SourceLocation()) { 947 gpr_log(GPR_INFO, "==> WaitForConnected()"); 948 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> final_picker; 949 WaitForStateUpdate( 950 [&](FakeHelper::StateUpdate update) { 951 if (update.state == GRPC_CHANNEL_CONNECTING) { 952 EXPECT_TRUE(update.status.ok()) 953 << update.status << " at " << location.file() << ":" 954 << location.line(); 955 ExpectPickQueued(update.picker.get(), {}, location); 956 return true; // Keep going. 957 } 958 EXPECT_EQ(update.state, GRPC_CHANNEL_READY) 959 << ConnectivityStateName(update.state) << " at " 960 << location.file() << ":" << location.line(); 961 final_picker = std::move(update.picker); 962 return false; // Stop. 963 }, 964 location); 965 return final_picker; 966 } 967 968 void ExpectTransientFailureUpdate( 969 absl::Status expected_status, 970 SourceLocation location = SourceLocation()) { 971 auto picker = 972 ExpectState(GRPC_CHANNEL_TRANSIENT_FAILURE, expected_status, location); 973 ASSERT_NE(picker, nullptr); 974 ExpectPickFail( 975 picker.get(), 976 [&](const absl::Status& status) { 977 EXPECT_EQ(status, expected_status) 978 << location.file() << ":" << location.line(); 979 }, 980 location); 981 } 982 983 // Waits for the LB policy to fail a connection attempt. There can be 984 // any number of CONNECTING updates, each of which must return a picker 985 // that queues picks, followed by one update for state TRANSIENT_FAILURE, 986 // whose status is passed to check_status() and whose picker must fail 987 // picks with a status that is passed to check_status(). 988 // Returns true if the reported states match expectations. 
989 bool WaitForConnectionFailed( 990 std::function<void(const absl::Status&)> check_status, 991 SourceLocation location = SourceLocation()) { 992 bool retval = false; 993 WaitForStateUpdate( 994 [&](FakeHelper::StateUpdate update) { 995 if (update.state == GRPC_CHANNEL_CONNECTING) { 996 EXPECT_TRUE(update.status.ok()) 997 << update.status << " at " << location.file() << ":" 998 << location.line(); 999 ExpectPickQueued(update.picker.get(), {}, location); 1000 return true; // Keep going. 1001 } 1002 EXPECT_EQ(update.state, GRPC_CHANNEL_TRANSIENT_FAILURE) 1003 << ConnectivityStateName(update.state) << " at " 1004 << location.file() << ":" << location.line(); 1005 check_status(update.status); 1006 ExpectPickFail(update.picker.get(), check_status, location); 1007 retval = update.state == GRPC_CHANNEL_TRANSIENT_FAILURE; 1008 return false; // Stop. 1009 }, 1010 location); 1011 return retval; 1012 } 1013 1014 // Waits for the round_robin policy to start using an updated address list. 1015 // There can be any number of READY updates where the picker is still using 1016 // the old list followed by one READY update where the picker is using the 1017 // new list. Returns a picker if the reported states match expectations. 
1018 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> 1019 WaitForRoundRobinListChange(absl::Span<const absl::string_view> old_addresses, 1020 absl::Span<const absl::string_view> new_addresses, 1021 const CallAttributes& call_attributes = {}, 1022 size_t num_iterations = 3, 1023 SourceLocation location = SourceLocation()) { 1024 gpr_log(GPR_INFO, "Waiting for expected RR addresses..."); 1025 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> retval; 1026 size_t num_picks = 1027 std::max(new_addresses.size(), old_addresses.size()) * num_iterations; 1028 WaitForStateUpdate( 1029 [&](FakeHelper::StateUpdate update) { 1030 EXPECT_EQ(update.state, GRPC_CHANNEL_READY) 1031 << location.file() << ":" << location.line(); 1032 if (update.state != GRPC_CHANNEL_READY) return false; 1033 // Get enough picks to round-robin num_iterations times across all 1034 // expected addresses. 1035 auto picks = GetCompletePicks(update.picker.get(), num_picks, 1036 call_attributes, nullptr, location); 1037 EXPECT_TRUE(picks.has_value()) 1038 << location.file() << ":" << location.line(); 1039 if (!picks.has_value()) return false; 1040 gpr_log(GPR_INFO, "PICKS: %s", absl::StrJoin(*picks, " ").c_str()); 1041 // If the picks still match the old list, then keep going. 1042 if (PicksAreRoundRobin(old_addresses, *picks)) return true; 1043 // Otherwise, the picks should match the new list. 1044 bool matches = PicksAreRoundRobin(new_addresses, *picks); 1045 EXPECT_TRUE(matches) 1046 << "Expected: " << absl::StrJoin(new_addresses, ", ") 1047 << "\nActual: " << absl::StrJoin(*picks, ", ") << "\nat " 1048 << location.file() << ":" << location.line(); 1049 if (matches) { 1050 retval = std::move(update.picker); 1051 } 1052 return false; // Stop. 1053 }, 1054 location); 1055 gpr_log(GPR_INFO, "done waiting for expected RR addresses"); 1056 return retval; 1057 } 1058 1059 // Expects a state update for the specified state and status, and then 1060 // expects the resulting picker to queue picks. 
1061 bool ExpectStateAndQueuingPicker( 1062 grpc_connectivity_state expected_state, 1063 absl::Status expected_status = absl::OkStatus(), 1064 SourceLocation location = SourceLocation()) { 1065 auto picker = ExpectState(expected_state, expected_status, location); 1066 return ExpectPickQueued(picker.get(), {}, location); 1067 } 1068 1069 // Convenient frontend to ExpectStateAndQueuingPicker() for CONNECTING. 1070 bool ExpectConnectingUpdate(SourceLocation location = SourceLocation()) { 1071 return ExpectStateAndQueuingPicker(GRPC_CHANNEL_CONNECTING, 1072 absl::OkStatus(), location); 1073 } 1074 1075 static std::unique_ptr<LoadBalancingPolicy::MetadataInterface> MakeMetadata( 1076 std::map<std::string, std::string> init = {}) { 1077 return std::make_unique<FakeMetadata>(init); 1078 } 1079 1080 // Does a pick and returns the result. 1081 LoadBalancingPolicy::PickResult DoPick( 1082 LoadBalancingPolicy::SubchannelPicker* picker, 1083 const CallAttributes& call_attributes = {}) { 1084 ExecCtx exec_ctx; 1085 FakeMetadata metadata({}); 1086 FakeCallState call_state(call_attributes); 1087 return picker->Pick({"/service/method", &metadata, &call_state}); 1088 } 1089 1090 // Requests a pick on picker and expects a Queue result. 1091 bool ExpectPickQueued(LoadBalancingPolicy::SubchannelPicker* picker, 1092 const CallAttributes call_attributes = {}, 1093 SourceLocation location = SourceLocation()) { 1094 EXPECT_NE(picker, nullptr) << location.file() << ":" << location.line(); 1095 if (picker == nullptr) return false; 1096 auto pick_result = DoPick(picker, call_attributes); 1097 EXPECT_TRUE(absl::holds_alternative<LoadBalancingPolicy::PickResult::Queue>( 1098 pick_result.result)) 1099 << PickResultString(pick_result) << "\nat " << location.file() << ":" 1100 << location.line(); 1101 return absl::holds_alternative<LoadBalancingPolicy::PickResult::Queue>( 1102 pick_result.result); 1103 } 1104 1105 // Requests a pick on picker and expects a Complete result. 
1106 // The address of the resulting subchannel is returned, or nullopt if 1107 // the result was something other than Complete. 1108 // If the complete pick includes a SubchannelCallTrackerInterface, then if 1109 // subchannel_call_tracker is non-null, it will be set to point to the 1110 // call tracker; otherwise, the call tracker will be invoked 1111 // automatically to represent a complete call with no backend metric data. 1112 absl::optional<std::string> ExpectPickComplete( 1113 LoadBalancingPolicy::SubchannelPicker* picker, 1114 const CallAttributes& call_attributes = {}, 1115 std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface>* 1116 subchannel_call_tracker = nullptr, 1117 SubchannelState::FakeSubchannel** picked_subchannel = nullptr, 1118 SourceLocation location = SourceLocation()) { 1119 EXPECT_NE(picker, nullptr); 1120 if (picker == nullptr) { 1121 return absl::nullopt; 1122 } 1123 auto pick_result = DoPick(picker, call_attributes); 1124 auto* complete = absl::get_if<LoadBalancingPolicy::PickResult::Complete>( 1125 &pick_result.result); 1126 EXPECT_NE(complete, nullptr) << PickResultString(pick_result) << " at " 1127 << location.file() << ":" << location.line(); 1128 if (complete == nullptr) return absl::nullopt; 1129 auto* subchannel = static_cast<SubchannelState::FakeSubchannel*>( 1130 complete->subchannel.get()); 1131 if (picked_subchannel != nullptr) *picked_subchannel = subchannel; 1132 std::string address = subchannel->state()->address(); 1133 if (complete->subchannel_call_tracker != nullptr) { 1134 if (subchannel_call_tracker != nullptr) { 1135 *subchannel_call_tracker = std::move(complete->subchannel_call_tracker); 1136 } else { 1137 ReportCompletionToCallTracker( 1138 std::move(complete->subchannel_call_tracker), address); 1139 } 1140 } 1141 return address; 1142 } 1143 1144 void ReportCompletionToCallTracker( 1145 std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface> 1146 subchannel_call_tracker, 1147 
absl::string_view address, absl::Status status = absl::OkStatus()) { 1148 subchannel_call_tracker->Start(); 1149 FakeMetadata metadata({}); 1150 FakeBackendMetricAccessor backend_metric_accessor({}); 1151 LoadBalancingPolicy::SubchannelCallTrackerInterface::FinishArgs args = { 1152 address, status, &metadata, &backend_metric_accessor}; 1153 subchannel_call_tracker->Finish(args); 1154 } 1155 1156 // Gets num_picks complete picks from picker and returns the resulting 1157 // list of addresses, or nullopt if a non-complete pick was returned. 1158 absl::optional<std::vector<std::string>> GetCompletePicks( 1159 LoadBalancingPolicy::SubchannelPicker* picker, size_t num_picks, 1160 const CallAttributes& call_attributes = {}, 1161 std::vector< 1162 std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface>>* 1163 subchannel_call_trackers = nullptr, 1164 SourceLocation location = SourceLocation()) { 1165 EXPECT_NE(picker, nullptr); 1166 if (picker == nullptr) { 1167 return absl::nullopt; 1168 } 1169 std::vector<std::string> results; 1170 for (size_t i = 0; i < num_picks; ++i) { 1171 std::unique_ptr<LoadBalancingPolicy::SubchannelCallTrackerInterface> 1172 subchannel_call_tracker; 1173 auto address = ExpectPickComplete(picker, call_attributes, 1174 subchannel_call_trackers == nullptr 1175 ? nullptr 1176 : &subchannel_call_tracker, 1177 nullptr, location); 1178 if (!address.has_value()) return absl::nullopt; 1179 results.emplace_back(std::move(*address)); 1180 if (subchannel_call_trackers != nullptr) { 1181 subchannel_call_trackers->emplace_back( 1182 std::move(subchannel_call_tracker)); 1183 } 1184 } 1185 return results; 1186 } 1187 1188 // Returns true if the list of actual pick result addresses matches the 1189 // list of expected addresses for round_robin. Note that the actual 1190 // addresses may start anywhere in the list of expected addresses but 1191 // must then continue in round-robin fashion, with wrap-around. 
PicksAreRoundRobin(absl::Span<const absl::string_view> expected,absl::Span<const std::string> actual)1192 bool PicksAreRoundRobin(absl::Span<const absl::string_view> expected, 1193 absl::Span<const std::string> actual) { 1194 absl::optional<size_t> expected_index; 1195 for (const auto& address : actual) { 1196 auto it = std::find(expected.begin(), expected.end(), address); 1197 if (it == expected.end()) return false; 1198 size_t index = it - expected.begin(); 1199 if (expected_index.has_value() && index != *expected_index) return false; 1200 expected_index = (index + 1) % expected.size(); 1201 } 1202 return true; 1203 } 1204 1205 // Checks that the picker has round-robin behavior over the specified 1206 // set of addresses. 1207 void ExpectRoundRobinPicks(LoadBalancingPolicy::SubchannelPicker* picker, 1208 absl::Span<const absl::string_view> addresses, 1209 const CallAttributes& call_attributes = {}, 1210 size_t num_iterations = 3, 1211 SourceLocation location = SourceLocation()) { 1212 auto picks = GetCompletePicks(picker, num_iterations * addresses.size(), 1213 call_attributes, nullptr, location); 1214 ASSERT_TRUE(picks.has_value()) << location.file() << ":" << location.line(); 1215 EXPECT_TRUE(PicksAreRoundRobin(addresses, *picks)) 1216 << " Actual: " << absl::StrJoin(*picks, ", ") 1217 << "\n Expected: " << absl::StrJoin(addresses, ", ") << "\n" 1218 << location.file() << ":" << location.line(); 1219 } 1220 1221 // Expect startup with RR with a set of addresses. 1222 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> ExpectRoundRobinStartup( 1223 absl::Span<const EndpointAddresses> endpoints, 1224 SourceLocation location = SourceLocation()) { 1225 GPR_ASSERT(!endpoints.empty()); 1226 // There should be a subchannel for every address. 1227 // We will wind up connecting to the first address for every endpoint. 
1228 std::vector<std::vector<SubchannelState*>> endpoint_subchannels; 1229 endpoint_subchannels.reserve(endpoints.size()); 1230 std::vector<std::string> chosen_addresses_storage; 1231 chosen_addresses_storage.reserve(endpoints.size()); 1232 std::vector<absl::string_view> chosen_addresses; 1233 chosen_addresses.reserve(endpoints.size()); 1234 for (const EndpointAddresses& endpoint : endpoints) { 1235 endpoint_subchannels.emplace_back(); 1236 endpoint_subchannels.back().reserve(endpoint.addresses().size()); 1237 for (size_t i = 0; i < endpoint.addresses().size(); ++i) { 1238 const grpc_resolved_address& address = endpoint.addresses()[i]; 1239 std::string address_str = grpc_sockaddr_to_uri(&address).value(); 1240 auto* subchannel = FindSubchannel(address_str); 1241 EXPECT_NE(subchannel, nullptr) 1242 << address_str << "\n" 1243 << location.file() << ":" << location.line(); 1244 if (subchannel == nullptr) return nullptr; 1245 endpoint_subchannels.back().push_back(subchannel); 1246 if (i == 0) { 1247 chosen_addresses_storage.emplace_back(std::move(address_str)); 1248 chosen_addresses.emplace_back(chosen_addresses_storage.back()); 1249 } 1250 } 1251 } 1252 // We should request a connection to the first address of each endpoint, 1253 // and not to any of the subsequent addresses. 1254 for (const auto& subchannels : endpoint_subchannels) { 1255 EXPECT_TRUE(subchannels[0]->ConnectionRequested()) 1256 << location.file() << ":" << location.line(); 1257 for (size_t i = 1; i < subchannels.size(); ++i) { 1258 EXPECT_FALSE(subchannels[i]->ConnectionRequested()) 1259 << "i=" << i << "\n" 1260 << location.file() << ":" << location.line(); 1261 } 1262 } 1263 // The subchannels that we've asked to connect should report 1264 // CONNECTING state. 
1265 for (size_t i = 0; i < endpoint_subchannels.size(); ++i) { 1266 endpoint_subchannels[i][0]->SetConnectivityState(GRPC_CHANNEL_CONNECTING); 1267 if (i == 0) ExpectConnectingUpdate(location); 1268 } 1269 // The connection attempts should succeed. 1270 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker; 1271 for (size_t i = 0; i < endpoint_subchannels.size(); ++i) { 1272 endpoint_subchannels[i][0]->SetConnectivityState(GRPC_CHANNEL_READY); 1273 if (i == 0) { 1274 // When the first subchannel becomes READY, accept any number of 1275 // CONNECTING updates with a picker that queues followed by a READY 1276 // update with a picker that repeatedly returns only the first address. 1277 picker = WaitForConnected(location); 1278 ExpectRoundRobinPicks(picker.get(), {chosen_addresses[0]}, {}, 3, 1279 location); 1280 } else { 1281 // When each subsequent subchannel becomes READY, we accept any number 1282 // of READY updates where the picker returns only the previously 1283 // connected subchannel(s) followed by a READY update where the picker 1284 // returns the previously connected subchannel(s) *and* the newly 1285 // connected subchannel. 1286 picker = WaitForRoundRobinListChange( 1287 absl::MakeSpan(chosen_addresses).subspan(0, i), 1288 absl::MakeSpan(chosen_addresses).subspan(0, i + 1), {}, 3, 1289 location); 1290 } 1291 } 1292 return picker; 1293 } 1294 1295 // A convenient override that takes a flat list of addresses, one per 1296 // endpoint. 1297 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> ExpectRoundRobinStartup( 1298 absl::Span<const absl::string_view> addresses, 1299 SourceLocation location = SourceLocation()) { 1300 return ExpectRoundRobinStartup( 1301 MakeEndpointAddressesListFromAddressList(addresses), location); 1302 } 1303 1304 // Expects zero or more picker updates, each of which returns 1305 // round-robin picks for the specified set of addresses. 
1306 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> 1307 DrainRoundRobinPickerUpdates(absl::Span<const absl::string_view> addresses, 1308 SourceLocation location = SourceLocation()) { 1309 gpr_log(GPR_INFO, "Draining RR picker updates..."); 1310 RefCountedPtr<LoadBalancingPolicy::SubchannelPicker> picker; 1311 while (!helper_->QueueEmpty()) { 1312 auto update = helper_->GetNextStateUpdate(location); 1313 EXPECT_TRUE(update.has_value()) 1314 << location.file() << ":" << location.line(); 1315 if (!update.has_value()) return nullptr; 1316 EXPECT_EQ(update->state, GRPC_CHANNEL_READY) 1317 << location.file() << ":" << location.line(); 1318 if (update->state != GRPC_CHANNEL_READY) return nullptr; 1319 ExpectRoundRobinPicks(update->picker.get(), addresses, 1320 /*call_attributes=*/{}, /*num_iterations=*/3, 1321 location); 1322 picker = std::move(update->picker); 1323 } 1324 gpr_log(GPR_INFO, "Done draining RR picker updates"); 1325 return picker; 1326 } 1327 1328 // Expects zero or more CONNECTING updates. 1329 void DrainConnectingUpdates(SourceLocation location = SourceLocation()) { 1330 gpr_log(GPR_INFO, "Draining CONNECTING updates..."); 1331 while (!helper_->QueueEmpty()) { 1332 ASSERT_TRUE(ExpectConnectingUpdate(location)); 1333 } 1334 gpr_log(GPR_INFO, "Done draining CONNECTING updates"); 1335 } 1336 1337 // Triggers a connection failure for the current address for an 1338 // endpoint and expects a reconnection to the specified new address. 
1339 void ExpectEndpointAddressChange( 1340 absl::Span<const absl::string_view> addresses, size_t current_index, 1341 size_t new_index, absl::AnyInvocable<void()> expect_after_disconnect, 1342 SourceLocation location = SourceLocation()) { 1343 gpr_log(GPR_INFO, 1344 "Expecting endpoint address change: addresses={%s}, " 1345 "current_index=%" PRIuPTR ", new_index=%" PRIuPTR, 1346 absl::StrJoin(addresses, ", ").c_str(), current_index, new_index); 1347 ASSERT_LT(current_index, addresses.size()); 1348 ASSERT_LT(new_index, addresses.size()); 1349 // Find all subchannels. 1350 std::vector<SubchannelState*> subchannels; 1351 subchannels.reserve(addresses.size()); 1352 for (absl::string_view address : addresses) { 1353 SubchannelState* subchannel = FindSubchannel(address); 1354 ASSERT_NE(subchannel, nullptr) 1355 << "can't find subchannel for " << address << "\n" 1356 << location.file() << ":" << location.line(); 1357 subchannels.push_back(subchannel); 1358 } 1359 // Cause current_address to become disconnected. 1360 subchannels[current_index]->SetConnectivityState(GRPC_CHANNEL_IDLE); 1361 ExpectReresolutionRequest(location); 1362 if (expect_after_disconnect != nullptr) expect_after_disconnect(); 1363 // Attempt each address in the list until we hit the desired new address. 1364 for (size_t i = 0; i < subchannels.size(); ++i) { 1365 // A connection should be requested on the subchannel for this 1366 // index, and none of the others. 1367 for (size_t j = 0; j < addresses.size(); ++j) { 1368 EXPECT_EQ(subchannels[j]->ConnectionRequested(), j == i) 1369 << location.file() << ":" << location.line(); 1370 } 1371 // Subchannel will report CONNECTING. 1372 SubchannelState* subchannel = subchannels[i]; 1373 subchannel->SetConnectivityState(GRPC_CHANNEL_CONNECTING); 1374 // If this is the one we want to stick with, it will report READY. 1375 if (i == new_index) { 1376 subchannel->SetConnectivityState(GRPC_CHANNEL_READY); 1377 break; 1378 } 1379 // Otherwise, report TF. 
1380 subchannel->SetConnectivityState( 1381 GRPC_CHANNEL_TRANSIENT_FAILURE, 1382 absl::UnavailableError("connection failed")); 1383 // Report IDLE to leave it in the expected state in case the test 1384 // interacts with it again. 1385 subchannel->SetConnectivityState(GRPC_CHANNEL_IDLE); 1386 } 1387 gpr_log(GPR_INFO, "Done with endpoint address change"); 1388 } 1389 1390 // Requests a picker on picker and expects a Fail result. 1391 // The failing status is passed to check_status. 1392 void ExpectPickFail(LoadBalancingPolicy::SubchannelPicker* picker, 1393 std::function<void(const absl::Status&)> check_status, 1394 SourceLocation location = SourceLocation()) { 1395 auto pick_result = DoPick(picker); 1396 auto* fail = absl::get_if<LoadBalancingPolicy::PickResult::Fail>( 1397 &pick_result.result); 1398 ASSERT_NE(fail, nullptr) << PickResultString(pick_result) << " at " 1399 << location.file() << ":" << location.line(); 1400 check_status(fail->status); 1401 } 1402 1403 // Returns a human-readable string for a pick result. 
PickResultString(const LoadBalancingPolicy::PickResult & result)1404 static std::string PickResultString( 1405 const LoadBalancingPolicy::PickResult& result) { 1406 return Match( 1407 result.result, 1408 [](const LoadBalancingPolicy::PickResult::Complete& complete) { 1409 auto* subchannel = static_cast<SubchannelState::FakeSubchannel*>( 1410 complete.subchannel.get()); 1411 return absl::StrFormat( 1412 "COMPLETE{subchannel=%s, subchannel_call_tracker=%p}", 1413 subchannel->state()->address(), 1414 complete.subchannel_call_tracker.get()); 1415 }, 1416 [](const LoadBalancingPolicy::PickResult::Queue&) -> std::string { 1417 return "QUEUE{}"; 1418 }, 1419 [](const LoadBalancingPolicy::PickResult::Fail& fail) -> std::string { 1420 return absl::StrFormat("FAIL{%s}", fail.status.ToString()); 1421 }, 1422 [](const LoadBalancingPolicy::PickResult::Drop& drop) -> std::string { 1423 return absl::StrFormat("FAIL{%s}", drop.status.ToString()); 1424 }); 1425 } 1426 1427 // Returns the entry in the subchannel pool, or null if not present. 1428 SubchannelState* FindSubchannel(absl::string_view address, 1429 const ChannelArgs& args = ChannelArgs()) { 1430 SubchannelKey key(MakeAddress(address), args); 1431 auto it = subchannel_pool_.find(key); 1432 if (it == subchannel_pool_.end()) return nullptr; 1433 return &it->second; 1434 } 1435 1436 // Creates and returns an entry in the subchannel pool. 1437 // This can be used in cases where we want to test that a subchannel 1438 // already exists when the LB policy creates it (e.g., due to it being 1439 // created by another channel and shared via the global subchannel 1440 // pool, or by being created by another LB policy in this channel). 
1441 SubchannelState* CreateSubchannel(absl::string_view address, 1442 const ChannelArgs& args = ChannelArgs()) { 1443 SubchannelKey key(MakeAddress(address), args); 1444 auto it = subchannel_pool_ 1445 .emplace(std::piecewise_construct, std::forward_as_tuple(key), 1446 std::forward_as_tuple(address, this)) 1447 .first; 1448 return &it->second; 1449 } 1450 WaitForWorkSerializerToFlush()1451 void WaitForWorkSerializerToFlush() { 1452 ExecCtx exec_ctx; 1453 gpr_log(GPR_INFO, "waiting for WorkSerializer to flush..."); 1454 absl::Notification notification; 1455 work_serializer_->Run([&]() { notification.Notify(); }, DEBUG_LOCATION); 1456 notification.WaitForNotification(); 1457 gpr_log(GPR_INFO, "WorkSerializer flush complete"); 1458 } 1459 IncrementTimeBy(Duration duration)1460 void IncrementTimeBy(Duration duration) { 1461 ExecCtx exec_ctx; 1462 gpr_log(GPR_INFO, "Incrementing time by %s...", 1463 duration.ToString().c_str()); 1464 fuzzing_ee_->TickForDuration(duration); 1465 gpr_log(GPR_INFO, "Done incrementing time"); 1466 // Flush WorkSerializer, in case the timer callback enqueued anything. 
1467 WaitForWorkSerializerToFlush(); 1468 } 1469 1470 void SetExpectedTimerDuration( 1471 absl::optional<grpc_event_engine::experimental::EventEngine::Duration> 1472 duration, 1473 SourceLocation location = SourceLocation()) { 1474 if (duration.has_value()) { 1475 fuzzing_ee_->SetRunAfterDurationCallback( 1476 [expected = *duration, location = location]( 1477 grpc_event_engine::experimental::EventEngine::Duration duration) { 1478 EXPECT_EQ(duration, expected) 1479 << "Expected: " << expected.count() 1480 << "ns\n Actual: " << duration.count() << "ns\n" 1481 << location.file() << ":" << location.line(); 1482 }); 1483 } else { 1484 fuzzing_ee_->SetRunAfterDurationCallback(nullptr); 1485 } 1486 } 1487 1488 std::shared_ptr<grpc_event_engine::experimental::FuzzingEventEngine> 1489 fuzzing_ee_; 1490 // TODO(ctiller): this is a normal event engine, yet it gets its time measure 1491 // from fuzzing_ee_ -- results are likely to be a little funky, but seem to do 1492 // well enough for the tests we have today. 1493 // We should transition everything here to just use fuzzing_ee_, but that 1494 // needs some thought on how to Tick() at appropriate times, as there are 1495 // Notification objects buried everywhere in this code, and 1496 // WaitForNotification is deeply incompatible with a single threaded event 1497 // engine that doesn't run callbacks until its public Tick method is called. 
std::shared_ptr<grpc_event_engine::experimental::EventEngine> event_engine_;
// Serializer on which the helpers above (ApplyUpdate(), ExitIdle(),
// WaitForWorkSerializerToFlush()) run work against the LB policy.
std::shared_ptr<WorkSerializer> work_serializer_;
// Non-owning pointer; the helpers above read events from it.
// NOTE(review): presumably points at a helper owned by lb_policy_ —
// confirm against the fixture's setup code.
FakeHelper* helper_ = nullptr;
// Subchannel states keyed by (address, channel args). Searched by
// FindSubchannel() and populated (at least) by CreateSubchannel().
std::map<SubchannelKey, SubchannelState> subchannel_pool_;
// The LB policy driven by helpers such as ExitIdle().
OrphanablePtr<LoadBalancingPolicy> lb_policy_;
// NOTE(review): non-owning view; the referenced name must outlive the
// fixture — confirm it is bound to a literal or other static storage.
const absl::string_view lb_policy_name_;
const ChannelArgs channel_args_;
GlobalStatsPluginRegistry::StatsPluginGroup stats_plugin_group_;
std::string target_ = "dns:server.example.com";
std::string authority_ = "server.example.com";
};

}  // namespace testing
}  // namespace grpc_core

#endif  // GRPC_TEST_CORE_CLIENT_CHANNEL_LB_POLICY_LB_POLICY_TEST_LIB_H