1 //
2 // Copyright 2018 gRPC authors.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16
17 #include "test/core/util/test_lb_policies.h"
18
19 #include <stdint.h>
20
21 #include <memory>
22 #include <string>
23
24 #include "absl/status/statusor.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/variant.h"
27
28 #include <grpc/grpc.h>
29 #include <grpc/support/json.h>
30 #include <grpc/support/log.h>
31
32 #include "src/core/lib/address_utils/parse_address.h"
33 #include "src/core/lib/channel/channel_args.h"
34 #include "src/core/lib/config/core_configuration.h"
35 #include "src/core/lib/gprpp/orphanable.h"
36 #include "src/core/lib/gprpp/ref_counted_ptr.h"
37 #include "src/core/lib/gprpp/status_helper.h"
38 #include "src/core/lib/gprpp/time.h"
39 #include "src/core/lib/iomgr/error.h"
40 #include "src/core/lib/iomgr/pollset_set.h"
41 #include "src/core/lib/iomgr/resolved_address.h"
42 #include "src/core/lib/json/json.h"
43 #include "src/core/lib/json/json_util.h"
44 #include "src/core/lib/uri/uri_parser.h"
45 #include "src/core/load_balancing/delegating_helper.h"
46 #include "src/core/load_balancing/lb_policy.h"
47 #include "src/core/load_balancing/lb_policy_factory.h"
48 #include "src/core/load_balancing/lb_policy_registry.h"
49 #include "src/core/load_balancing/oob_backend_metric.h"
50 #include "src/core/load_balancing/subchannel_interface.h"
51
52 namespace grpc_core {
53
54 namespace {
55
56 //
57 // ForwardingLoadBalancingPolicy
58 //
59
60 // A minimal forwarding class to avoid implementing a standalone test LB.
61 class ForwardingLoadBalancingPolicy : public LoadBalancingPolicy {
62 public:
ForwardingLoadBalancingPolicy(std::unique_ptr<ChannelControlHelper> delegating_helper,Args args,absl::string_view delegate_policy_name,intptr_t initial_refcount=1)63 ForwardingLoadBalancingPolicy(
64 std::unique_ptr<ChannelControlHelper> delegating_helper, Args args,
65 absl::string_view delegate_policy_name, intptr_t initial_refcount = 1)
66 : LoadBalancingPolicy(std::move(args), initial_refcount) {
67 Args delegate_args;
68 delegate_args.work_serializer = work_serializer();
69 delegate_args.channel_control_helper = std::move(delegating_helper);
70 delegate_args.args = channel_args();
71 delegate_ =
72 CoreConfiguration::Get().lb_policy_registry().CreateLoadBalancingPolicy(
73 delegate_policy_name, std::move(delegate_args));
74 grpc_pollset_set_add_pollset_set(delegate_->interested_parties(),
75 interested_parties());
76 }
77
78 ~ForwardingLoadBalancingPolicy() override = default;
79
UpdateLocked(UpdateArgs args)80 absl::Status UpdateLocked(UpdateArgs args) override {
81 // Use correct config for the delegate load balancing policy
82 auto config =
83 CoreConfiguration::Get().lb_policy_registry().ParseLoadBalancingConfig(
84 Json::FromArray({Json::FromObject(
85 {{std::string(delegate_->name()), Json::FromObject({})}})}));
86 GPR_ASSERT(config.ok());
87 args.config = *config;
88 return delegate_->UpdateLocked(std::move(args));
89 }
90
ExitIdleLocked()91 void ExitIdleLocked() override { delegate_->ExitIdleLocked(); }
92
ResetBackoffLocked()93 void ResetBackoffLocked() override { delegate_->ResetBackoffLocked(); }
94
95 private:
ShutdownLocked()96 void ShutdownLocked() override { delegate_.reset(); }
97
98 OrphanablePtr<LoadBalancingPolicy> delegate_;
99 };
100
101 //
102 // TestPickArgsLb
103 //
104
105 constexpr absl::string_view kTestPickArgsLbPolicyName = "test_pick_args_lb";
106
107 class TestPickArgsLb : public ForwardingLoadBalancingPolicy {
108 public:
TestPickArgsLb(Args args,TestPickArgsCallback cb,absl::string_view delegate_policy_name)109 TestPickArgsLb(Args args, TestPickArgsCallback cb,
110 absl::string_view delegate_policy_name)
111 : ForwardingLoadBalancingPolicy(
112 std::make_unique<Helper>(RefCountedPtr<TestPickArgsLb>(this), cb),
113 std::move(args), delegate_policy_name,
114 /*initial_refcount=*/2) {}
115
116 ~TestPickArgsLb() override = default;
117
name() const118 absl::string_view name() const override { return kTestPickArgsLbPolicyName; }
119
120 private:
121 class Picker : public SubchannelPicker {
122 public:
Picker(RefCountedPtr<SubchannelPicker> delegate_picker,TestPickArgsCallback cb)123 Picker(RefCountedPtr<SubchannelPicker> delegate_picker,
124 TestPickArgsCallback cb)
125 : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {}
126
Pick(PickArgs args)127 PickResult Pick(PickArgs args) override {
128 // Report args seen.
129 PickArgsSeen args_seen;
130 args_seen.path = std::string(args.path);
131 args_seen.metadata = args.initial_metadata->TestOnlyCopyToVector();
132 cb_(args_seen);
133 // Do pick.
134 return delegate_picker_->Pick(args);
135 }
136
137 private:
138 RefCountedPtr<SubchannelPicker> delegate_picker_;
139 TestPickArgsCallback cb_;
140 };
141
142 class Helper
143 : public ParentOwningDelegatingChannelControlHelper<TestPickArgsLb> {
144 public:
Helper(RefCountedPtr<TestPickArgsLb> parent,TestPickArgsCallback cb)145 Helper(RefCountedPtr<TestPickArgsLb> parent, TestPickArgsCallback cb)
146 : ParentOwningDelegatingChannelControlHelper(std::move(parent)),
147 cb_(std::move(cb)) {}
148
UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<SubchannelPicker> picker)149 void UpdateState(grpc_connectivity_state state, const absl::Status& status,
150 RefCountedPtr<SubchannelPicker> picker) override {
151 parent_helper()->UpdateState(
152 state, status, MakeRefCounted<Picker>(std::move(picker), cb_));
153 }
154
155 private:
156 TestPickArgsCallback cb_;
157 };
158 };
159
160 class TestPickArgsLbConfig : public LoadBalancingPolicy::Config {
161 public:
name() const162 absl::string_view name() const override { return kTestPickArgsLbPolicyName; }
163 };
164
165 class TestPickArgsLbFactory : public LoadBalancingPolicyFactory {
166 public:
TestPickArgsLbFactory(TestPickArgsCallback cb,absl::string_view delegate_policy_name)167 explicit TestPickArgsLbFactory(TestPickArgsCallback cb,
168 absl::string_view delegate_policy_name)
169 : cb_(std::move(cb)), delegate_policy_name_(delegate_policy_name) {}
170
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const171 OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
172 LoadBalancingPolicy::Args args) const override {
173 return MakeOrphanable<TestPickArgsLb>(std::move(args), cb_,
174 delegate_policy_name_);
175 }
176
name() const177 absl::string_view name() const override { return kTestPickArgsLbPolicyName; }
178
179 absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const180 ParseLoadBalancingConfig(const Json& /*json*/) const override {
181 return MakeRefCounted<TestPickArgsLbConfig>();
182 }
183
184 private:
185 TestPickArgsCallback cb_;
186 std::string delegate_policy_name_;
187 };
188
189 //
190 // InterceptRecvTrailingMetadataLoadBalancingPolicy
191 //
192
193 constexpr char kInterceptRecvTrailingMetadataLbPolicyName[] =
194 "intercept_trailing_metadata_lb";
195
196 class InterceptRecvTrailingMetadataLoadBalancingPolicy
197 : public ForwardingLoadBalancingPolicy {
198 public:
InterceptRecvTrailingMetadataLoadBalancingPolicy(Args args,InterceptRecvTrailingMetadataCallback cb)199 InterceptRecvTrailingMetadataLoadBalancingPolicy(
200 Args args, InterceptRecvTrailingMetadataCallback cb)
201 : ForwardingLoadBalancingPolicy(
202 std::make_unique<Helper>(
203 RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy>(
204 this),
205 std::move(cb)),
206 std::move(args),
207 /*delegate_policy_name=*/"pick_first",
208 /*initial_refcount=*/2) {}
209
210 ~InterceptRecvTrailingMetadataLoadBalancingPolicy() override = default;
211
name() const212 absl::string_view name() const override {
213 return kInterceptRecvTrailingMetadataLbPolicyName;
214 }
215
216 private:
217 class Picker : public SubchannelPicker {
218 public:
Picker(RefCountedPtr<SubchannelPicker> delegate_picker,InterceptRecvTrailingMetadataCallback cb)219 Picker(RefCountedPtr<SubchannelPicker> delegate_picker,
220 InterceptRecvTrailingMetadataCallback cb)
221 : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {}
222
Pick(PickArgs args)223 PickResult Pick(PickArgs args) override {
224 // Do pick.
225 PickResult result = delegate_picker_->Pick(args);
226 // Intercept trailing metadata.
227 auto* complete_pick = absl::get_if<PickResult::Complete>(&result.result);
228 if (complete_pick != nullptr) {
229 complete_pick->subchannel_call_tracker =
230 std::make_unique<SubchannelCallTracker>(cb_);
231 }
232 return result;
233 }
234
235 private:
236 RefCountedPtr<SubchannelPicker> delegate_picker_;
237 InterceptRecvTrailingMetadataCallback cb_;
238 };
239
240 class Helper : public ParentOwningDelegatingChannelControlHelper<
241 InterceptRecvTrailingMetadataLoadBalancingPolicy> {
242 public:
Helper(RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent,InterceptRecvTrailingMetadataCallback cb)243 Helper(
244 RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent,
245 InterceptRecvTrailingMetadataCallback cb)
246 : ParentOwningDelegatingChannelControlHelper(std::move(parent)),
247 cb_(std::move(cb)) {}
248
UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<SubchannelPicker> picker)249 void UpdateState(grpc_connectivity_state state, const absl::Status& status,
250 RefCountedPtr<SubchannelPicker> picker) override {
251 parent_helper()->UpdateState(
252 state, status, MakeRefCounted<Picker>(std::move(picker), cb_));
253 }
254
255 private:
256 InterceptRecvTrailingMetadataCallback cb_;
257 };
258
259 class SubchannelCallTracker : public SubchannelCallTrackerInterface {
260 public:
SubchannelCallTracker(InterceptRecvTrailingMetadataCallback cb)261 explicit SubchannelCallTracker(InterceptRecvTrailingMetadataCallback cb)
262 : cb_(std::move(cb)) {}
263
Start()264 void Start() override {}
265
Finish(FinishArgs args)266 void Finish(FinishArgs args) override {
267 TrailingMetadataArgsSeen args_seen;
268 args_seen.status = args.status;
269 args_seen.backend_metric_data =
270 args.backend_metric_accessor->GetBackendMetricData();
271 args_seen.metadata = args.trailing_metadata->TestOnlyCopyToVector();
272 cb_(args_seen);
273 }
274
275 private:
276 InterceptRecvTrailingMetadataCallback cb_;
277 };
278 };
279
280 class InterceptTrailingConfig : public LoadBalancingPolicy::Config {
281 public:
name() const282 absl::string_view name() const override {
283 return kInterceptRecvTrailingMetadataLbPolicyName;
284 }
285 };
286
287 class InterceptTrailingFactory : public LoadBalancingPolicyFactory {
288 public:
InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb)289 explicit InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb)
290 : cb_(std::move(cb)) {}
291
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const292 OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
293 LoadBalancingPolicy::Args args) const override {
294 return MakeOrphanable<InterceptRecvTrailingMetadataLoadBalancingPolicy>(
295 std::move(args), cb_);
296 }
297
name() const298 absl::string_view name() const override {
299 return kInterceptRecvTrailingMetadataLbPolicyName;
300 }
301
302 absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const303 ParseLoadBalancingConfig(const Json& /*json*/) const override {
304 return MakeRefCounted<InterceptTrailingConfig>();
305 }
306
307 private:
308 InterceptRecvTrailingMetadataCallback cb_;
309 };
310
311 //
312 // AddressTestLoadBalancingPolicy
313 //
314
315 constexpr char kAddressTestLbPolicyName[] = "address_test_lb";
316
317 class AddressTestLoadBalancingPolicy : public ForwardingLoadBalancingPolicy {
318 public:
AddressTestLoadBalancingPolicy(Args args,AddressTestCallback cb)319 AddressTestLoadBalancingPolicy(Args args, AddressTestCallback cb)
320 : ForwardingLoadBalancingPolicy(
321 std::make_unique<Helper>(
322 RefCountedPtr<AddressTestLoadBalancingPolicy>(this),
323 std::move(cb)),
324 std::move(args),
325 /*delegate_policy_name=*/"pick_first",
326 /*initial_refcount=*/2) {}
327
328 ~AddressTestLoadBalancingPolicy() override = default;
329
name() const330 absl::string_view name() const override { return kAddressTestLbPolicyName; }
331
332 private:
333 class Helper : public ParentOwningDelegatingChannelControlHelper<
334 AddressTestLoadBalancingPolicy> {
335 public:
Helper(RefCountedPtr<AddressTestLoadBalancingPolicy> parent,AddressTestCallback cb)336 Helper(RefCountedPtr<AddressTestLoadBalancingPolicy> parent,
337 AddressTestCallback cb)
338 : ParentOwningDelegatingChannelControlHelper(std::move(parent)),
339 cb_(std::move(cb)) {}
340
CreateSubchannel(const grpc_resolved_address & address,const ChannelArgs & per_address_args,const ChannelArgs & args)341 RefCountedPtr<SubchannelInterface> CreateSubchannel(
342 const grpc_resolved_address& address,
343 const ChannelArgs& per_address_args, const ChannelArgs& args) override {
344 cb_(EndpointAddresses(address, per_address_args));
345 return parent_helper()->CreateSubchannel(address, per_address_args, args);
346 }
347
348 private:
349 AddressTestCallback cb_;
350 };
351 };
352
353 class AddressTestConfig : public LoadBalancingPolicy::Config {
354 public:
name() const355 absl::string_view name() const override { return kAddressTestLbPolicyName; }
356 };
357
358 class AddressTestFactory : public LoadBalancingPolicyFactory {
359 public:
AddressTestFactory(AddressTestCallback cb)360 explicit AddressTestFactory(AddressTestCallback cb) : cb_(std::move(cb)) {}
361
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const362 OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
363 LoadBalancingPolicy::Args args) const override {
364 return MakeOrphanable<AddressTestLoadBalancingPolicy>(std::move(args), cb_);
365 }
366
name() const367 absl::string_view name() const override { return kAddressTestLbPolicyName; }
368
369 absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const370 ParseLoadBalancingConfig(const Json& /*json*/) const override {
371 return MakeRefCounted<AddressTestConfig>();
372 }
373
374 private:
375 AddressTestCallback cb_;
376 };
377
378 //
379 // FixedAddressLoadBalancingPolicy
380 //
381
382 constexpr char kFixedAddressLbPolicyName[] = "fixed_address_lb";
383
384 class FixedAddressConfig : public LoadBalancingPolicy::Config {
385 public:
FixedAddressConfig(std::string address)386 explicit FixedAddressConfig(std::string address)
387 : address_(std::move(address)) {}
388
name() const389 absl::string_view name() const override { return kFixedAddressLbPolicyName; }
390
address() const391 const std::string& address() const { return address_; }
392
393 private:
394 std::string address_;
395 };
396
397 class FixedAddressLoadBalancingPolicy : public ForwardingLoadBalancingPolicy {
398 public:
FixedAddressLoadBalancingPolicy(Args args)399 explicit FixedAddressLoadBalancingPolicy(Args args)
400 : ForwardingLoadBalancingPolicy(
401 std::make_unique<Helper>(
402 RefCountedPtr<FixedAddressLoadBalancingPolicy>(this)),
403 std::move(args),
404 /*delegate_policy_name=*/"pick_first",
405 /*initial_refcount=*/2) {}
406
407 ~FixedAddressLoadBalancingPolicy() override = default;
408
name() const409 absl::string_view name() const override { return kFixedAddressLbPolicyName; }
410
UpdateLocked(UpdateArgs args)411 absl::Status UpdateLocked(UpdateArgs args) override {
412 auto* config = static_cast<FixedAddressConfig*>(args.config.get());
413 gpr_log(GPR_INFO, "%s: update URI: %s", kFixedAddressLbPolicyName,
414 config->address().c_str());
415 auto uri = URI::Parse(config->address());
416 args.config.reset();
417 EndpointAddressesList addresses;
418 if (uri.ok()) {
419 grpc_resolved_address address;
420 GPR_ASSERT(grpc_parse_uri(*uri, &address));
421 addresses.emplace_back(address, ChannelArgs());
422 } else {
423 gpr_log(GPR_ERROR,
424 "%s: could not parse URI (%s), using empty address list",
425 kFixedAddressLbPolicyName, uri.status().ToString().c_str());
426 args.resolution_note = "no address in fixed_address_lb policy";
427 }
428 args.addresses =
429 std::make_shared<EndpointAddressesListIterator>(std::move(addresses));
430 return ForwardingLoadBalancingPolicy::UpdateLocked(std::move(args));
431 }
432
433 private:
434 using Helper = ParentOwningDelegatingChannelControlHelper<
435 FixedAddressLoadBalancingPolicy>;
436 };
437
438 class FixedAddressFactory : public LoadBalancingPolicyFactory {
439 public:
440 FixedAddressFactory() = default;
441
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const442 OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
443 LoadBalancingPolicy::Args args) const override {
444 return MakeOrphanable<FixedAddressLoadBalancingPolicy>(std::move(args));
445 }
446
name() const447 absl::string_view name() const override { return kFixedAddressLbPolicyName; }
448
449 absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json & json) const450 ParseLoadBalancingConfig(const Json& json) const override {
451 std::vector<grpc_error_handle> error_list;
452 std::string address;
453 ParseJsonObjectField(json.object(), "address", &address, &error_list);
454 if (!error_list.empty()) {
455 grpc_error_handle error = GRPC_ERROR_CREATE_FROM_VECTOR(
456 "errors parsing fixed_address_lb config", &error_list);
457 absl::Status status = absl::InvalidArgumentError(StatusToString(error));
458 return status;
459 }
460 return MakeRefCounted<FixedAddressConfig>(std::move(address));
461 }
462 };
463
464 //
465 // OobBackendMetricTestLoadBalancingPolicy
466 //
467
468 constexpr char kOobBackendMetricTestLbPolicyName[] =
469 "oob_backend_metric_test_lb";
470
471 class OobBackendMetricTestConfig : public LoadBalancingPolicy::Config {
472 public:
name() const473 absl::string_view name() const override {
474 return kOobBackendMetricTestLbPolicyName;
475 }
476 };
477
478 class OobBackendMetricTestLoadBalancingPolicy
479 : public ForwardingLoadBalancingPolicy {
480 public:
OobBackendMetricTestLoadBalancingPolicy(Args args,OobBackendMetricCallback cb)481 OobBackendMetricTestLoadBalancingPolicy(Args args,
482 OobBackendMetricCallback cb)
483 : ForwardingLoadBalancingPolicy(
484 std::make_unique<Helper>(
485 RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy>(this)),
486 std::move(args),
487 /*delegate_policy_name=*/"pick_first",
488 /*initial_refcount=*/2),
489 cb_(std::move(cb)) {}
490
491 ~OobBackendMetricTestLoadBalancingPolicy() override = default;
492
name() const493 absl::string_view name() const override {
494 return kOobBackendMetricTestLbPolicyName;
495 }
496
497 private:
498 class BackendMetricWatcher : public OobBackendMetricWatcher {
499 public:
BackendMetricWatcher(EndpointAddresses address,RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy> parent)500 BackendMetricWatcher(
501 EndpointAddresses address,
502 RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy> parent)
503 : address_(std::move(address)), parent_(std::move(parent)) {}
504
OnBackendMetricReport(const BackendMetricData & backend_metric_data)505 void OnBackendMetricReport(
506 const BackendMetricData& backend_metric_data) override {
507 parent_->cb_(address_, backend_metric_data);
508 }
509
510 private:
511 EndpointAddresses address_;
512 RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy> parent_;
513 };
514
515 class Helper : public ParentOwningDelegatingChannelControlHelper<
516 OobBackendMetricTestLoadBalancingPolicy> {
517 public:
Helper(RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy> parent)518 explicit Helper(
519 RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy> parent)
520 : ParentOwningDelegatingChannelControlHelper(std::move(parent)) {}
521
CreateSubchannel(const grpc_resolved_address & address,const ChannelArgs & per_address_args,const ChannelArgs & args)522 RefCountedPtr<SubchannelInterface> CreateSubchannel(
523 const grpc_resolved_address& address,
524 const ChannelArgs& per_address_args, const ChannelArgs& args) override {
525 auto subchannel =
526 parent_helper()->CreateSubchannel(address, per_address_args, args);
527 subchannel->AddDataWatcher(MakeOobBackendMetricWatcher(
528 Duration::Seconds(1),
529 std::make_unique<BackendMetricWatcher>(
530 EndpointAddresses(address, per_address_args),
531 parent()
532 ->RefAsSubclass<OobBackendMetricTestLoadBalancingPolicy>())));
533 return subchannel;
534 }
535 };
536
537 OobBackendMetricCallback cb_;
538 };
539
540 class OobBackendMetricTestFactory : public LoadBalancingPolicyFactory {
541 public:
OobBackendMetricTestFactory(OobBackendMetricCallback cb)542 explicit OobBackendMetricTestFactory(OobBackendMetricCallback cb)
543 : cb_(std::move(cb)) {}
544
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const545 OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
546 LoadBalancingPolicy::Args args) const override {
547 return MakeOrphanable<OobBackendMetricTestLoadBalancingPolicy>(
548 std::move(args), cb_);
549 }
550
name() const551 absl::string_view name() const override {
552 return kOobBackendMetricTestLbPolicyName;
553 }
554
555 absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const556 ParseLoadBalancingConfig(const Json& /*json*/) const override {
557 return MakeRefCounted<OobBackendMetricTestConfig>();
558 }
559
560 private:
561 OobBackendMetricCallback cb_;
562 };
563
564 //
565 // FailLoadBalancingPolicy
566 //
567
568 constexpr char kFailPolicyName[] = "fail_lb";
569
570 class FailPolicy : public LoadBalancingPolicy {
571 public:
FailPolicy(Args args,absl::Status status,std::atomic<int> * pick_counter)572 FailPolicy(Args args, absl::Status status, std::atomic<int>* pick_counter)
573 : LoadBalancingPolicy(std::move(args)),
574 status_(std::move(status)),
575 pick_counter_(pick_counter) {}
576
name() const577 absl::string_view name() const override { return kFailPolicyName; }
578
UpdateLocked(UpdateArgs)579 absl::Status UpdateLocked(UpdateArgs) override {
580 channel_control_helper()->UpdateState(
581 GRPC_CHANNEL_TRANSIENT_FAILURE, status_,
582 MakeRefCounted<FailPicker>(status_, pick_counter_));
583 return absl::OkStatus();
584 }
585
ResetBackoffLocked()586 void ResetBackoffLocked() override {}
ShutdownLocked()587 void ShutdownLocked() override {}
588
589 private:
590 class FailPicker : public SubchannelPicker {
591 public:
FailPicker(absl::Status status,std::atomic<int> * pick_counter)592 FailPicker(absl::Status status, std::atomic<int>* pick_counter)
593 : status_(std::move(status)), pick_counter_(pick_counter) {}
594
Pick(PickArgs)595 PickResult Pick(PickArgs /*args*/) override {
596 if (pick_counter_ != nullptr) pick_counter_->fetch_add(1);
597 return PickResult::Fail(status_);
598 }
599
600 private:
601 absl::Status status_;
602 std::atomic<int>* pick_counter_;
603 };
604
605 absl::Status status_;
606 std::atomic<int>* pick_counter_;
607 };
608
609 class FailLbConfig : public LoadBalancingPolicy::Config {
610 public:
name() const611 absl::string_view name() const override { return kFailPolicyName; }
612 };
613
614 class FailLbFactory : public LoadBalancingPolicyFactory {
615 public:
FailLbFactory(absl::Status status,std::atomic<int> * pick_counter)616 FailLbFactory(absl::Status status, std::atomic<int>* pick_counter)
617 : status_(std::move(status)), pick_counter_(pick_counter) {}
618
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const619 OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
620 LoadBalancingPolicy::Args args) const override {
621 return MakeOrphanable<FailPolicy>(std::move(args), status_, pick_counter_);
622 }
623
name() const624 absl::string_view name() const override { return kFailPolicyName; }
625
626 absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const627 ParseLoadBalancingConfig(const Json& /*json*/) const override {
628 return MakeRefCounted<FailLbConfig>();
629 }
630
631 private:
632 absl::Status status_;
633 std::atomic<int>* pick_counter_;
634 };
635
636 //
637 // QueueOnceLoadBalancingPolicy - a load balancing policy that provides a Queue
638 // PickResult at least once, after which it delegates to PickFirst.
639 //
640
641 constexpr char kQueueOncePolicyName[] = "queue_once";
642
643 class QueueOnceLoadBalancingPolicy : public ForwardingLoadBalancingPolicy {
644 public:
QueueOnceLoadBalancingPolicy(Args args)645 explicit QueueOnceLoadBalancingPolicy(Args args)
646 : ForwardingLoadBalancingPolicy(
647 std::make_unique<Helper>(
648 RefCountedPtr<QueueOnceLoadBalancingPolicy>(this)),
649 std::move(args), "pick_first",
650 /*initial_refcount=*/2) {}
651
652 // We use the standard QueuePicker which invokes ExitIdleLocked() on the first
653 // pick.
ExitIdleLocked()654 void ExitIdleLocked() override {
655 bool needs_update = !std::exchange(seen_pick_queued_, true);
656 if (needs_update) {
657 channel_control_helper()->UpdateState(state_to_update_.state,
658 state_to_update_.status,
659 std::move(state_to_update_.picker));
660 }
661 }
662
name() const663 absl::string_view name() const override { return kQueueOncePolicyName; }
664
665 private:
666 class Helper : public ParentOwningDelegatingChannelControlHelper<
667 QueueOnceLoadBalancingPolicy> {
668 public:
Helper(RefCountedPtr<QueueOnceLoadBalancingPolicy> parent)669 explicit Helper(RefCountedPtr<QueueOnceLoadBalancingPolicy> parent)
670 : ParentOwningDelegatingChannelControlHelper(std::move(parent)) {}
671
UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<SubchannelPicker> picker)672 void UpdateState(grpc_connectivity_state state, const absl::Status& status,
673 RefCountedPtr<SubchannelPicker> picker) override {
674 // If we've already seen a queued pick, just propagate the update
675 // directly.
676 if (parent()->seen_pick_queued_) {
677 parent()->channel_control_helper()->UpdateState(state, status,
678 std::move(picker));
679 return;
680 }
681 // Otherwise, store the update in the LB policy, to be propagated later,
682 // and return a queueing picker.
683 parent()->state_to_update_ = {state, status, std::move(picker)};
684 parent_helper()->UpdateState(
685 state, status, MakeRefCounted<QueuePicker>(parent()->Ref()));
686 }
687 };
688 struct StateToUpdate {
689 grpc_connectivity_state state;
690 absl::Status status;
691 RefCountedPtr<SubchannelPicker> picker;
692 };
693 StateToUpdate state_to_update_;
694 bool seen_pick_queued_ = false; // Has a pick been queued yet. Only accessed
695 // from within the WorkSerializer.
696 };
697
698 class QueueOnceLbConfig : public LoadBalancingPolicy::Config {
699 public:
name() const700 absl::string_view name() const override { return kQueueOncePolicyName; }
701 };
702
703 class QueueOnceLoadBalancingPolicyFactory : public LoadBalancingPolicyFactory {
704 public:
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const705 OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
706 LoadBalancingPolicy::Args args) const override {
707 return MakeOrphanable<QueueOnceLoadBalancingPolicy>(std::move(args));
708 }
709
name() const710 absl::string_view name() const override { return kQueueOncePolicyName; }
711
712 absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const713 ParseLoadBalancingConfig(const Json& /*json*/) const override {
714 return MakeRefCounted<QueueOnceLbConfig>();
715 }
716 };
717
718 } // namespace
719
RegisterTestPickArgsLoadBalancingPolicy(CoreConfiguration::Builder * builder,TestPickArgsCallback cb,absl::string_view delegate_policy_name)720 void RegisterTestPickArgsLoadBalancingPolicy(
721 CoreConfiguration::Builder* builder, TestPickArgsCallback cb,
722 absl::string_view delegate_policy_name) {
723 builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
724 std::make_unique<TestPickArgsLbFactory>(std::move(cb),
725 delegate_policy_name));
726 }
727
RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy(CoreConfiguration::Builder * builder,InterceptRecvTrailingMetadataCallback cb)728 void RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy(
729 CoreConfiguration::Builder* builder,
730 InterceptRecvTrailingMetadataCallback cb) {
731 builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
732 std::make_unique<InterceptTrailingFactory>(std::move(cb)));
733 }
734
RegisterAddressTestLoadBalancingPolicy(CoreConfiguration::Builder * builder,AddressTestCallback cb)735 void RegisterAddressTestLoadBalancingPolicy(CoreConfiguration::Builder* builder,
736 AddressTestCallback cb) {
737 builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
738 std::make_unique<AddressTestFactory>(std::move(cb)));
739 }
740
RegisterFixedAddressLoadBalancingPolicy(CoreConfiguration::Builder * builder)741 void RegisterFixedAddressLoadBalancingPolicy(
742 CoreConfiguration::Builder* builder) {
743 builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
744 std::make_unique<FixedAddressFactory>());
745 }
746
RegisterOobBackendMetricTestLoadBalancingPolicy(CoreConfiguration::Builder * builder,OobBackendMetricCallback cb)747 void RegisterOobBackendMetricTestLoadBalancingPolicy(
748 CoreConfiguration::Builder* builder, OobBackendMetricCallback cb) {
749 builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
750 std::make_unique<OobBackendMetricTestFactory>(std::move(cb)));
751 }
752
RegisterFailLoadBalancingPolicy(CoreConfiguration::Builder * builder,absl::Status status,std::atomic<int> * pick_counter)753 void RegisterFailLoadBalancingPolicy(CoreConfiguration::Builder* builder,
754 absl::Status status,
755 std::atomic<int>* pick_counter) {
756 builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
757 std::make_unique<FailLbFactory>(std::move(status), pick_counter));
758 }
759
RegisterQueueOnceLoadBalancingPolicy(CoreConfiguration::Builder * builder)760 void RegisterQueueOnceLoadBalancingPolicy(CoreConfiguration::Builder* builder) {
761 builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
762 std::make_unique<QueueOnceLoadBalancingPolicyFactory>());
763 }
764
765 } // namespace grpc_core
766