xref: /aosp_15_r20/external/grpc-grpc/test/core/util/test_lb_policies.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 //
2 // Copyright 2018 gRPC authors.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //     http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 
17 #include "test/core/util/test_lb_policies.h"
18 
19 #include <stdint.h>
20 
21 #include <memory>
22 #include <string>
23 
24 #include "absl/status/statusor.h"
25 #include "absl/strings/string_view.h"
26 #include "absl/types/variant.h"
27 
28 #include <grpc/grpc.h>
29 #include <grpc/support/json.h>
30 #include <grpc/support/log.h>
31 
32 #include "src/core/lib/address_utils/parse_address.h"
33 #include "src/core/lib/channel/channel_args.h"
34 #include "src/core/lib/config/core_configuration.h"
35 #include "src/core/lib/gprpp/orphanable.h"
36 #include "src/core/lib/gprpp/ref_counted_ptr.h"
37 #include "src/core/lib/gprpp/status_helper.h"
38 #include "src/core/lib/gprpp/time.h"
39 #include "src/core/lib/iomgr/error.h"
40 #include "src/core/lib/iomgr/pollset_set.h"
41 #include "src/core/lib/iomgr/resolved_address.h"
42 #include "src/core/lib/json/json.h"
43 #include "src/core/lib/json/json_util.h"
44 #include "src/core/lib/uri/uri_parser.h"
45 #include "src/core/load_balancing/delegating_helper.h"
46 #include "src/core/load_balancing/lb_policy.h"
47 #include "src/core/load_balancing/lb_policy_factory.h"
48 #include "src/core/load_balancing/lb_policy_registry.h"
49 #include "src/core/load_balancing/oob_backend_metric.h"
50 #include "src/core/load_balancing/subchannel_interface.h"
51 
52 namespace grpc_core {
53 
54 namespace {
55 
56 //
57 // ForwardingLoadBalancingPolicy
58 //
59 
60 // A minimal forwarding class to avoid implementing a standalone test LB.
61 class ForwardingLoadBalancingPolicy : public LoadBalancingPolicy {
62  public:
ForwardingLoadBalancingPolicy(std::unique_ptr<ChannelControlHelper> delegating_helper,Args args,absl::string_view delegate_policy_name,intptr_t initial_refcount=1)63   ForwardingLoadBalancingPolicy(
64       std::unique_ptr<ChannelControlHelper> delegating_helper, Args args,
65       absl::string_view delegate_policy_name, intptr_t initial_refcount = 1)
66       : LoadBalancingPolicy(std::move(args), initial_refcount) {
67     Args delegate_args;
68     delegate_args.work_serializer = work_serializer();
69     delegate_args.channel_control_helper = std::move(delegating_helper);
70     delegate_args.args = channel_args();
71     delegate_ =
72         CoreConfiguration::Get().lb_policy_registry().CreateLoadBalancingPolicy(
73             delegate_policy_name, std::move(delegate_args));
74     grpc_pollset_set_add_pollset_set(delegate_->interested_parties(),
75                                      interested_parties());
76   }
77 
78   ~ForwardingLoadBalancingPolicy() override = default;
79 
UpdateLocked(UpdateArgs args)80   absl::Status UpdateLocked(UpdateArgs args) override {
81     // Use correct config for the delegate load balancing policy
82     auto config =
83         CoreConfiguration::Get().lb_policy_registry().ParseLoadBalancingConfig(
84             Json::FromArray({Json::FromObject(
85                 {{std::string(delegate_->name()), Json::FromObject({})}})}));
86     GPR_ASSERT(config.ok());
87     args.config = *config;
88     return delegate_->UpdateLocked(std::move(args));
89   }
90 
ExitIdleLocked()91   void ExitIdleLocked() override { delegate_->ExitIdleLocked(); }
92 
ResetBackoffLocked()93   void ResetBackoffLocked() override { delegate_->ResetBackoffLocked(); }
94 
95  private:
ShutdownLocked()96   void ShutdownLocked() override { delegate_.reset(); }
97 
98   OrphanablePtr<LoadBalancingPolicy> delegate_;
99 };
100 
101 //
102 // TestPickArgsLb
103 //
104 
105 constexpr absl::string_view kTestPickArgsLbPolicyName = "test_pick_args_lb";
106 
107 class TestPickArgsLb : public ForwardingLoadBalancingPolicy {
108  public:
TestPickArgsLb(Args args,TestPickArgsCallback cb,absl::string_view delegate_policy_name)109   TestPickArgsLb(Args args, TestPickArgsCallback cb,
110                  absl::string_view delegate_policy_name)
111       : ForwardingLoadBalancingPolicy(
112             std::make_unique<Helper>(RefCountedPtr<TestPickArgsLb>(this), cb),
113             std::move(args), delegate_policy_name,
114             /*initial_refcount=*/2) {}
115 
116   ~TestPickArgsLb() override = default;
117 
name() const118   absl::string_view name() const override { return kTestPickArgsLbPolicyName; }
119 
120  private:
121   class Picker : public SubchannelPicker {
122    public:
Picker(RefCountedPtr<SubchannelPicker> delegate_picker,TestPickArgsCallback cb)123     Picker(RefCountedPtr<SubchannelPicker> delegate_picker,
124            TestPickArgsCallback cb)
125         : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {}
126 
Pick(PickArgs args)127     PickResult Pick(PickArgs args) override {
128       // Report args seen.
129       PickArgsSeen args_seen;
130       args_seen.path = std::string(args.path);
131       args_seen.metadata = args.initial_metadata->TestOnlyCopyToVector();
132       cb_(args_seen);
133       // Do pick.
134       return delegate_picker_->Pick(args);
135     }
136 
137    private:
138     RefCountedPtr<SubchannelPicker> delegate_picker_;
139     TestPickArgsCallback cb_;
140   };
141 
142   class Helper
143       : public ParentOwningDelegatingChannelControlHelper<TestPickArgsLb> {
144    public:
Helper(RefCountedPtr<TestPickArgsLb> parent,TestPickArgsCallback cb)145     Helper(RefCountedPtr<TestPickArgsLb> parent, TestPickArgsCallback cb)
146         : ParentOwningDelegatingChannelControlHelper(std::move(parent)),
147           cb_(std::move(cb)) {}
148 
UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<SubchannelPicker> picker)149     void UpdateState(grpc_connectivity_state state, const absl::Status& status,
150                      RefCountedPtr<SubchannelPicker> picker) override {
151       parent_helper()->UpdateState(
152           state, status, MakeRefCounted<Picker>(std::move(picker), cb_));
153     }
154 
155    private:
156     TestPickArgsCallback cb_;
157   };
158 };
159 
160 class TestPickArgsLbConfig : public LoadBalancingPolicy::Config {
161  public:
name() const162   absl::string_view name() const override { return kTestPickArgsLbPolicyName; }
163 };
164 
165 class TestPickArgsLbFactory : public LoadBalancingPolicyFactory {
166  public:
TestPickArgsLbFactory(TestPickArgsCallback cb,absl::string_view delegate_policy_name)167   explicit TestPickArgsLbFactory(TestPickArgsCallback cb,
168                                  absl::string_view delegate_policy_name)
169       : cb_(std::move(cb)), delegate_policy_name_(delegate_policy_name) {}
170 
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const171   OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
172       LoadBalancingPolicy::Args args) const override {
173     return MakeOrphanable<TestPickArgsLb>(std::move(args), cb_,
174                                           delegate_policy_name_);
175   }
176 
name() const177   absl::string_view name() const override { return kTestPickArgsLbPolicyName; }
178 
179   absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const180   ParseLoadBalancingConfig(const Json& /*json*/) const override {
181     return MakeRefCounted<TestPickArgsLbConfig>();
182   }
183 
184  private:
185   TestPickArgsCallback cb_;
186   std::string delegate_policy_name_;
187 };
188 
189 //
190 // InterceptRecvTrailingMetadataLoadBalancingPolicy
191 //
192 
193 constexpr char kInterceptRecvTrailingMetadataLbPolicyName[] =
194     "intercept_trailing_metadata_lb";
195 
196 class InterceptRecvTrailingMetadataLoadBalancingPolicy
197     : public ForwardingLoadBalancingPolicy {
198  public:
InterceptRecvTrailingMetadataLoadBalancingPolicy(Args args,InterceptRecvTrailingMetadataCallback cb)199   InterceptRecvTrailingMetadataLoadBalancingPolicy(
200       Args args, InterceptRecvTrailingMetadataCallback cb)
201       : ForwardingLoadBalancingPolicy(
202             std::make_unique<Helper>(
203                 RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy>(
204                     this),
205                 std::move(cb)),
206             std::move(args),
207             /*delegate_policy_name=*/"pick_first",
208             /*initial_refcount=*/2) {}
209 
210   ~InterceptRecvTrailingMetadataLoadBalancingPolicy() override = default;
211 
name() const212   absl::string_view name() const override {
213     return kInterceptRecvTrailingMetadataLbPolicyName;
214   }
215 
216  private:
217   class Picker : public SubchannelPicker {
218    public:
Picker(RefCountedPtr<SubchannelPicker> delegate_picker,InterceptRecvTrailingMetadataCallback cb)219     Picker(RefCountedPtr<SubchannelPicker> delegate_picker,
220            InterceptRecvTrailingMetadataCallback cb)
221         : delegate_picker_(std::move(delegate_picker)), cb_(std::move(cb)) {}
222 
Pick(PickArgs args)223     PickResult Pick(PickArgs args) override {
224       // Do pick.
225       PickResult result = delegate_picker_->Pick(args);
226       // Intercept trailing metadata.
227       auto* complete_pick = absl::get_if<PickResult::Complete>(&result.result);
228       if (complete_pick != nullptr) {
229         complete_pick->subchannel_call_tracker =
230             std::make_unique<SubchannelCallTracker>(cb_);
231       }
232       return result;
233     }
234 
235    private:
236     RefCountedPtr<SubchannelPicker> delegate_picker_;
237     InterceptRecvTrailingMetadataCallback cb_;
238   };
239 
240   class Helper : public ParentOwningDelegatingChannelControlHelper<
241                      InterceptRecvTrailingMetadataLoadBalancingPolicy> {
242    public:
Helper(RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent,InterceptRecvTrailingMetadataCallback cb)243     Helper(
244         RefCountedPtr<InterceptRecvTrailingMetadataLoadBalancingPolicy> parent,
245         InterceptRecvTrailingMetadataCallback cb)
246         : ParentOwningDelegatingChannelControlHelper(std::move(parent)),
247           cb_(std::move(cb)) {}
248 
UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<SubchannelPicker> picker)249     void UpdateState(grpc_connectivity_state state, const absl::Status& status,
250                      RefCountedPtr<SubchannelPicker> picker) override {
251       parent_helper()->UpdateState(
252           state, status, MakeRefCounted<Picker>(std::move(picker), cb_));
253     }
254 
255    private:
256     InterceptRecvTrailingMetadataCallback cb_;
257   };
258 
259   class SubchannelCallTracker : public SubchannelCallTrackerInterface {
260    public:
SubchannelCallTracker(InterceptRecvTrailingMetadataCallback cb)261     explicit SubchannelCallTracker(InterceptRecvTrailingMetadataCallback cb)
262         : cb_(std::move(cb)) {}
263 
Start()264     void Start() override {}
265 
Finish(FinishArgs args)266     void Finish(FinishArgs args) override {
267       TrailingMetadataArgsSeen args_seen;
268       args_seen.status = args.status;
269       args_seen.backend_metric_data =
270           args.backend_metric_accessor->GetBackendMetricData();
271       args_seen.metadata = args.trailing_metadata->TestOnlyCopyToVector();
272       cb_(args_seen);
273     }
274 
275    private:
276     InterceptRecvTrailingMetadataCallback cb_;
277   };
278 };
279 
280 class InterceptTrailingConfig : public LoadBalancingPolicy::Config {
281  public:
name() const282   absl::string_view name() const override {
283     return kInterceptRecvTrailingMetadataLbPolicyName;
284   }
285 };
286 
287 class InterceptTrailingFactory : public LoadBalancingPolicyFactory {
288  public:
InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb)289   explicit InterceptTrailingFactory(InterceptRecvTrailingMetadataCallback cb)
290       : cb_(std::move(cb)) {}
291 
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const292   OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
293       LoadBalancingPolicy::Args args) const override {
294     return MakeOrphanable<InterceptRecvTrailingMetadataLoadBalancingPolicy>(
295         std::move(args), cb_);
296   }
297 
name() const298   absl::string_view name() const override {
299     return kInterceptRecvTrailingMetadataLbPolicyName;
300   }
301 
302   absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const303   ParseLoadBalancingConfig(const Json& /*json*/) const override {
304     return MakeRefCounted<InterceptTrailingConfig>();
305   }
306 
307  private:
308   InterceptRecvTrailingMetadataCallback cb_;
309 };
310 
311 //
312 // AddressTestLoadBalancingPolicy
313 //
314 
315 constexpr char kAddressTestLbPolicyName[] = "address_test_lb";
316 
317 class AddressTestLoadBalancingPolicy : public ForwardingLoadBalancingPolicy {
318  public:
AddressTestLoadBalancingPolicy(Args args,AddressTestCallback cb)319   AddressTestLoadBalancingPolicy(Args args, AddressTestCallback cb)
320       : ForwardingLoadBalancingPolicy(
321             std::make_unique<Helper>(
322                 RefCountedPtr<AddressTestLoadBalancingPolicy>(this),
323                 std::move(cb)),
324             std::move(args),
325             /*delegate_policy_name=*/"pick_first",
326             /*initial_refcount=*/2) {}
327 
328   ~AddressTestLoadBalancingPolicy() override = default;
329 
name() const330   absl::string_view name() const override { return kAddressTestLbPolicyName; }
331 
332  private:
333   class Helper : public ParentOwningDelegatingChannelControlHelper<
334                      AddressTestLoadBalancingPolicy> {
335    public:
Helper(RefCountedPtr<AddressTestLoadBalancingPolicy> parent,AddressTestCallback cb)336     Helper(RefCountedPtr<AddressTestLoadBalancingPolicy> parent,
337            AddressTestCallback cb)
338         : ParentOwningDelegatingChannelControlHelper(std::move(parent)),
339           cb_(std::move(cb)) {}
340 
CreateSubchannel(const grpc_resolved_address & address,const ChannelArgs & per_address_args,const ChannelArgs & args)341     RefCountedPtr<SubchannelInterface> CreateSubchannel(
342         const grpc_resolved_address& address,
343         const ChannelArgs& per_address_args, const ChannelArgs& args) override {
344       cb_(EndpointAddresses(address, per_address_args));
345       return parent_helper()->CreateSubchannel(address, per_address_args, args);
346     }
347 
348    private:
349     AddressTestCallback cb_;
350   };
351 };
352 
353 class AddressTestConfig : public LoadBalancingPolicy::Config {
354  public:
name() const355   absl::string_view name() const override { return kAddressTestLbPolicyName; }
356 };
357 
358 class AddressTestFactory : public LoadBalancingPolicyFactory {
359  public:
AddressTestFactory(AddressTestCallback cb)360   explicit AddressTestFactory(AddressTestCallback cb) : cb_(std::move(cb)) {}
361 
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const362   OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
363       LoadBalancingPolicy::Args args) const override {
364     return MakeOrphanable<AddressTestLoadBalancingPolicy>(std::move(args), cb_);
365   }
366 
name() const367   absl::string_view name() const override { return kAddressTestLbPolicyName; }
368 
369   absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const370   ParseLoadBalancingConfig(const Json& /*json*/) const override {
371     return MakeRefCounted<AddressTestConfig>();
372   }
373 
374  private:
375   AddressTestCallback cb_;
376 };
377 
378 //
379 // FixedAddressLoadBalancingPolicy
380 //
381 
382 constexpr char kFixedAddressLbPolicyName[] = "fixed_address_lb";
383 
384 class FixedAddressConfig : public LoadBalancingPolicy::Config {
385  public:
FixedAddressConfig(std::string address)386   explicit FixedAddressConfig(std::string address)
387       : address_(std::move(address)) {}
388 
name() const389   absl::string_view name() const override { return kFixedAddressLbPolicyName; }
390 
address() const391   const std::string& address() const { return address_; }
392 
393  private:
394   std::string address_;
395 };
396 
397 class FixedAddressLoadBalancingPolicy : public ForwardingLoadBalancingPolicy {
398  public:
FixedAddressLoadBalancingPolicy(Args args)399   explicit FixedAddressLoadBalancingPolicy(Args args)
400       : ForwardingLoadBalancingPolicy(
401             std::make_unique<Helper>(
402                 RefCountedPtr<FixedAddressLoadBalancingPolicy>(this)),
403             std::move(args),
404             /*delegate_policy_name=*/"pick_first",
405             /*initial_refcount=*/2) {}
406 
407   ~FixedAddressLoadBalancingPolicy() override = default;
408 
name() const409   absl::string_view name() const override { return kFixedAddressLbPolicyName; }
410 
UpdateLocked(UpdateArgs args)411   absl::Status UpdateLocked(UpdateArgs args) override {
412     auto* config = static_cast<FixedAddressConfig*>(args.config.get());
413     gpr_log(GPR_INFO, "%s: update URI: %s", kFixedAddressLbPolicyName,
414             config->address().c_str());
415     auto uri = URI::Parse(config->address());
416     args.config.reset();
417     EndpointAddressesList addresses;
418     if (uri.ok()) {
419       grpc_resolved_address address;
420       GPR_ASSERT(grpc_parse_uri(*uri, &address));
421       addresses.emplace_back(address, ChannelArgs());
422     } else {
423       gpr_log(GPR_ERROR,
424               "%s: could not parse URI (%s), using empty address list",
425               kFixedAddressLbPolicyName, uri.status().ToString().c_str());
426       args.resolution_note = "no address in fixed_address_lb policy";
427     }
428     args.addresses =
429         std::make_shared<EndpointAddressesListIterator>(std::move(addresses));
430     return ForwardingLoadBalancingPolicy::UpdateLocked(std::move(args));
431   }
432 
433  private:
434   using Helper = ParentOwningDelegatingChannelControlHelper<
435       FixedAddressLoadBalancingPolicy>;
436 };
437 
438 class FixedAddressFactory : public LoadBalancingPolicyFactory {
439  public:
440   FixedAddressFactory() = default;
441 
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const442   OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
443       LoadBalancingPolicy::Args args) const override {
444     return MakeOrphanable<FixedAddressLoadBalancingPolicy>(std::move(args));
445   }
446 
name() const447   absl::string_view name() const override { return kFixedAddressLbPolicyName; }
448 
449   absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json & json) const450   ParseLoadBalancingConfig(const Json& json) const override {
451     std::vector<grpc_error_handle> error_list;
452     std::string address;
453     ParseJsonObjectField(json.object(), "address", &address, &error_list);
454     if (!error_list.empty()) {
455       grpc_error_handle error = GRPC_ERROR_CREATE_FROM_VECTOR(
456           "errors parsing fixed_address_lb config", &error_list);
457       absl::Status status = absl::InvalidArgumentError(StatusToString(error));
458       return status;
459     }
460     return MakeRefCounted<FixedAddressConfig>(std::move(address));
461   }
462 };
463 
464 //
465 // OobBackendMetricTestLoadBalancingPolicy
466 //
467 
468 constexpr char kOobBackendMetricTestLbPolicyName[] =
469     "oob_backend_metric_test_lb";
470 
471 class OobBackendMetricTestConfig : public LoadBalancingPolicy::Config {
472  public:
name() const473   absl::string_view name() const override {
474     return kOobBackendMetricTestLbPolicyName;
475   }
476 };
477 
478 class OobBackendMetricTestLoadBalancingPolicy
479     : public ForwardingLoadBalancingPolicy {
480  public:
OobBackendMetricTestLoadBalancingPolicy(Args args,OobBackendMetricCallback cb)481   OobBackendMetricTestLoadBalancingPolicy(Args args,
482                                           OobBackendMetricCallback cb)
483       : ForwardingLoadBalancingPolicy(
484             std::make_unique<Helper>(
485                 RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy>(this)),
486             std::move(args),
487             /*delegate_policy_name=*/"pick_first",
488             /*initial_refcount=*/2),
489         cb_(std::move(cb)) {}
490 
491   ~OobBackendMetricTestLoadBalancingPolicy() override = default;
492 
name() const493   absl::string_view name() const override {
494     return kOobBackendMetricTestLbPolicyName;
495   }
496 
497  private:
498   class BackendMetricWatcher : public OobBackendMetricWatcher {
499    public:
BackendMetricWatcher(EndpointAddresses address,RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy> parent)500     BackendMetricWatcher(
501         EndpointAddresses address,
502         RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy> parent)
503         : address_(std::move(address)), parent_(std::move(parent)) {}
504 
OnBackendMetricReport(const BackendMetricData & backend_metric_data)505     void OnBackendMetricReport(
506         const BackendMetricData& backend_metric_data) override {
507       parent_->cb_(address_, backend_metric_data);
508     }
509 
510    private:
511     EndpointAddresses address_;
512     RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy> parent_;
513   };
514 
515   class Helper : public ParentOwningDelegatingChannelControlHelper<
516                      OobBackendMetricTestLoadBalancingPolicy> {
517    public:
Helper(RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy> parent)518     explicit Helper(
519         RefCountedPtr<OobBackendMetricTestLoadBalancingPolicy> parent)
520         : ParentOwningDelegatingChannelControlHelper(std::move(parent)) {}
521 
CreateSubchannel(const grpc_resolved_address & address,const ChannelArgs & per_address_args,const ChannelArgs & args)522     RefCountedPtr<SubchannelInterface> CreateSubchannel(
523         const grpc_resolved_address& address,
524         const ChannelArgs& per_address_args, const ChannelArgs& args) override {
525       auto subchannel =
526           parent_helper()->CreateSubchannel(address, per_address_args, args);
527       subchannel->AddDataWatcher(MakeOobBackendMetricWatcher(
528           Duration::Seconds(1),
529           std::make_unique<BackendMetricWatcher>(
530               EndpointAddresses(address, per_address_args),
531               parent()
532                   ->RefAsSubclass<OobBackendMetricTestLoadBalancingPolicy>())));
533       return subchannel;
534     }
535   };
536 
537   OobBackendMetricCallback cb_;
538 };
539 
540 class OobBackendMetricTestFactory : public LoadBalancingPolicyFactory {
541  public:
OobBackendMetricTestFactory(OobBackendMetricCallback cb)542   explicit OobBackendMetricTestFactory(OobBackendMetricCallback cb)
543       : cb_(std::move(cb)) {}
544 
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const545   OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
546       LoadBalancingPolicy::Args args) const override {
547     return MakeOrphanable<OobBackendMetricTestLoadBalancingPolicy>(
548         std::move(args), cb_);
549   }
550 
name() const551   absl::string_view name() const override {
552     return kOobBackendMetricTestLbPolicyName;
553   }
554 
555   absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const556   ParseLoadBalancingConfig(const Json& /*json*/) const override {
557     return MakeRefCounted<OobBackendMetricTestConfig>();
558   }
559 
560  private:
561   OobBackendMetricCallback cb_;
562 };
563 
564 //
565 // FailLoadBalancingPolicy
566 //
567 
568 constexpr char kFailPolicyName[] = "fail_lb";
569 
570 class FailPolicy : public LoadBalancingPolicy {
571  public:
FailPolicy(Args args,absl::Status status,std::atomic<int> * pick_counter)572   FailPolicy(Args args, absl::Status status, std::atomic<int>* pick_counter)
573       : LoadBalancingPolicy(std::move(args)),
574         status_(std::move(status)),
575         pick_counter_(pick_counter) {}
576 
name() const577   absl::string_view name() const override { return kFailPolicyName; }
578 
UpdateLocked(UpdateArgs)579   absl::Status UpdateLocked(UpdateArgs) override {
580     channel_control_helper()->UpdateState(
581         GRPC_CHANNEL_TRANSIENT_FAILURE, status_,
582         MakeRefCounted<FailPicker>(status_, pick_counter_));
583     return absl::OkStatus();
584   }
585 
ResetBackoffLocked()586   void ResetBackoffLocked() override {}
ShutdownLocked()587   void ShutdownLocked() override {}
588 
589  private:
590   class FailPicker : public SubchannelPicker {
591    public:
FailPicker(absl::Status status,std::atomic<int> * pick_counter)592     FailPicker(absl::Status status, std::atomic<int>* pick_counter)
593         : status_(std::move(status)), pick_counter_(pick_counter) {}
594 
Pick(PickArgs)595     PickResult Pick(PickArgs /*args*/) override {
596       if (pick_counter_ != nullptr) pick_counter_->fetch_add(1);
597       return PickResult::Fail(status_);
598     }
599 
600    private:
601     absl::Status status_;
602     std::atomic<int>* pick_counter_;
603   };
604 
605   absl::Status status_;
606   std::atomic<int>* pick_counter_;
607 };
608 
609 class FailLbConfig : public LoadBalancingPolicy::Config {
610  public:
name() const611   absl::string_view name() const override { return kFailPolicyName; }
612 };
613 
614 class FailLbFactory : public LoadBalancingPolicyFactory {
615  public:
FailLbFactory(absl::Status status,std::atomic<int> * pick_counter)616   FailLbFactory(absl::Status status, std::atomic<int>* pick_counter)
617       : status_(std::move(status)), pick_counter_(pick_counter) {}
618 
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const619   OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
620       LoadBalancingPolicy::Args args) const override {
621     return MakeOrphanable<FailPolicy>(std::move(args), status_, pick_counter_);
622   }
623 
name() const624   absl::string_view name() const override { return kFailPolicyName; }
625 
626   absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const627   ParseLoadBalancingConfig(const Json& /*json*/) const override {
628     return MakeRefCounted<FailLbConfig>();
629   }
630 
631  private:
632   absl::Status status_;
633   std::atomic<int>* pick_counter_;
634 };
635 
636 //
637 // QueueOnceLoadBalancingPolicy - a load balancing policy that provides a Queue
638 // PickResult at least once, after which it delegates to PickFirst.
639 //
640 
641 constexpr char kQueueOncePolicyName[] = "queue_once";
642 
643 class QueueOnceLoadBalancingPolicy : public ForwardingLoadBalancingPolicy {
644  public:
QueueOnceLoadBalancingPolicy(Args args)645   explicit QueueOnceLoadBalancingPolicy(Args args)
646       : ForwardingLoadBalancingPolicy(
647             std::make_unique<Helper>(
648                 RefCountedPtr<QueueOnceLoadBalancingPolicy>(this)),
649             std::move(args), "pick_first",
650             /*initial_refcount=*/2) {}
651 
652   // We use the standard QueuePicker which invokes ExitIdleLocked() on the first
653   // pick.
ExitIdleLocked()654   void ExitIdleLocked() override {
655     bool needs_update = !std::exchange(seen_pick_queued_, true);
656     if (needs_update) {
657       channel_control_helper()->UpdateState(state_to_update_.state,
658                                             state_to_update_.status,
659                                             std::move(state_to_update_.picker));
660     }
661   }
662 
name() const663   absl::string_view name() const override { return kQueueOncePolicyName; }
664 
665  private:
666   class Helper : public ParentOwningDelegatingChannelControlHelper<
667                      QueueOnceLoadBalancingPolicy> {
668    public:
Helper(RefCountedPtr<QueueOnceLoadBalancingPolicy> parent)669     explicit Helper(RefCountedPtr<QueueOnceLoadBalancingPolicy> parent)
670         : ParentOwningDelegatingChannelControlHelper(std::move(parent)) {}
671 
UpdateState(grpc_connectivity_state state,const absl::Status & status,RefCountedPtr<SubchannelPicker> picker)672     void UpdateState(grpc_connectivity_state state, const absl::Status& status,
673                      RefCountedPtr<SubchannelPicker> picker) override {
674       // If we've already seen a queued pick, just propagate the update
675       // directly.
676       if (parent()->seen_pick_queued_) {
677         parent()->channel_control_helper()->UpdateState(state, status,
678                                                         std::move(picker));
679         return;
680       }
681       // Otherwise, store the update in the LB policy, to be propagated later,
682       // and return a queueing picker.
683       parent()->state_to_update_ = {state, status, std::move(picker)};
684       parent_helper()->UpdateState(
685           state, status, MakeRefCounted<QueuePicker>(parent()->Ref()));
686     }
687   };
688   struct StateToUpdate {
689     grpc_connectivity_state state;
690     absl::Status status;
691     RefCountedPtr<SubchannelPicker> picker;
692   };
693   StateToUpdate state_to_update_;
694   bool seen_pick_queued_ = false;  // Has a pick been queued yet. Only accessed
695                                    // from within the WorkSerializer.
696 };
697 
698 class QueueOnceLbConfig : public LoadBalancingPolicy::Config {
699  public:
name() const700   absl::string_view name() const override { return kQueueOncePolicyName; }
701 };
702 
703 class QueueOnceLoadBalancingPolicyFactory : public LoadBalancingPolicyFactory {
704  public:
CreateLoadBalancingPolicy(LoadBalancingPolicy::Args args) const705   OrphanablePtr<LoadBalancingPolicy> CreateLoadBalancingPolicy(
706       LoadBalancingPolicy::Args args) const override {
707     return MakeOrphanable<QueueOnceLoadBalancingPolicy>(std::move(args));
708   }
709 
name() const710   absl::string_view name() const override { return kQueueOncePolicyName; }
711 
712   absl::StatusOr<RefCountedPtr<LoadBalancingPolicy::Config>>
ParseLoadBalancingConfig(const Json &) const713   ParseLoadBalancingConfig(const Json& /*json*/) const override {
714     return MakeRefCounted<QueueOnceLbConfig>();
715   }
716 };
717 
718 }  // namespace
719 
RegisterTestPickArgsLoadBalancingPolicy(CoreConfiguration::Builder * builder,TestPickArgsCallback cb,absl::string_view delegate_policy_name)720 void RegisterTestPickArgsLoadBalancingPolicy(
721     CoreConfiguration::Builder* builder, TestPickArgsCallback cb,
722     absl::string_view delegate_policy_name) {
723   builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
724       std::make_unique<TestPickArgsLbFactory>(std::move(cb),
725                                               delegate_policy_name));
726 }
727 
RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy(CoreConfiguration::Builder * builder,InterceptRecvTrailingMetadataCallback cb)728 void RegisterInterceptRecvTrailingMetadataLoadBalancingPolicy(
729     CoreConfiguration::Builder* builder,
730     InterceptRecvTrailingMetadataCallback cb) {
731   builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
732       std::make_unique<InterceptTrailingFactory>(std::move(cb)));
733 }
734 
RegisterAddressTestLoadBalancingPolicy(CoreConfiguration::Builder * builder,AddressTestCallback cb)735 void RegisterAddressTestLoadBalancingPolicy(CoreConfiguration::Builder* builder,
736                                             AddressTestCallback cb) {
737   builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
738       std::make_unique<AddressTestFactory>(std::move(cb)));
739 }
740 
RegisterFixedAddressLoadBalancingPolicy(CoreConfiguration::Builder * builder)741 void RegisterFixedAddressLoadBalancingPolicy(
742     CoreConfiguration::Builder* builder) {
743   builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
744       std::make_unique<FixedAddressFactory>());
745 }
746 
RegisterOobBackendMetricTestLoadBalancingPolicy(CoreConfiguration::Builder * builder,OobBackendMetricCallback cb)747 void RegisterOobBackendMetricTestLoadBalancingPolicy(
748     CoreConfiguration::Builder* builder, OobBackendMetricCallback cb) {
749   builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
750       std::make_unique<OobBackendMetricTestFactory>(std::move(cb)));
751 }
752 
RegisterFailLoadBalancingPolicy(CoreConfiguration::Builder * builder,absl::Status status,std::atomic<int> * pick_counter)753 void RegisterFailLoadBalancingPolicy(CoreConfiguration::Builder* builder,
754                                      absl::Status status,
755                                      std::atomic<int>* pick_counter) {
756   builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
757       std::make_unique<FailLbFactory>(std::move(status), pick_counter));
758 }
759 
RegisterQueueOnceLoadBalancingPolicy(CoreConfiguration::Builder * builder)760 void RegisterQueueOnceLoadBalancingPolicy(CoreConfiguration::Builder* builder) {
761   builder->lb_policy_registry()->RegisterLoadBalancingPolicyFactory(
762       std::make_unique<QueueOnceLoadBalancingPolicyFactory>());
763 }
764 
765 }  // namespace grpc_core
766