1 // Copyright (c) 2022 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifndef QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
6 #define QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
7
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
11
12 #include "absl/container/flat_hash_map.h"
13 #include "quiche/quic/load_balancer/load_balancer_server_id.h"
14 #include "quiche/quic/platform/api/quic_bug_tracker.h"
15
16 namespace quic {
17
18 // This class wraps an absl::flat_hash_map which associates server IDs to an
19 // arbitrary type T. It validates that all server ids are of the same fixed
20 // length. This might be used by a load balancer to connect a server ID with a
21 // pool member data structure.
22 template <typename T>
23 class QUIC_EXPORT_PRIVATE LoadBalancerServerIdMap {
24 public:
25 // Returns a newly created pool for server IDs of length |server_id_len|, or
26 // nullptr if |server_id_len| is invalid.
27 static std::shared_ptr<LoadBalancerServerIdMap> Create(uint8_t server_id_len);
28
29 // Returns the entry associated with |server_id|, if present. For small |T|,
30 // use Lookup. For large |T|, use LookupNoCopy.
31 std::optional<const T> Lookup(LoadBalancerServerId server_id) const;
32 const T* LookupNoCopy(LoadBalancerServerId server_id) const;
33
34 // Updates the table so that |value| is associated with |server_id|. Sets
35 // QUIC_BUG if the length is incorrect for this map.
36 void AddOrReplace(LoadBalancerServerId server_id, T value);
37
38 // Removes the entry associated with |server_id|.
Erase(const LoadBalancerServerId server_id)39 void Erase(const LoadBalancerServerId server_id) {
40 server_id_table_.erase(server_id);
41 }
42
server_id_len()43 uint8_t server_id_len() const { return server_id_len_; }
44
45 private:
LoadBalancerServerIdMap(uint8_t server_id_len)46 LoadBalancerServerIdMap(uint8_t server_id_len)
47 : server_id_len_(server_id_len) {}
48
49 const uint8_t server_id_len_; // All server IDs must be of this length.
50 absl::flat_hash_map<LoadBalancerServerId, T> server_id_table_;
51 };
52
53 template <typename T>
Create(const uint8_t server_id_len)54 std::shared_ptr<LoadBalancerServerIdMap<T>> LoadBalancerServerIdMap<T>::Create(
55 const uint8_t server_id_len) {
56 if (server_id_len == 0 || server_id_len > kLoadBalancerMaxServerIdLen) {
57 QUIC_BUG(quic_bug_434893339_01)
58 << "Tried to configure map with server ID length "
59 << static_cast<int>(server_id_len);
60 return nullptr;
61 }
62 return std::make_shared<LoadBalancerServerIdMap<T>>(
63 LoadBalancerServerIdMap(server_id_len));
64 }
65
66 template <typename T>
Lookup(const LoadBalancerServerId server_id)67 std::optional<const T> LoadBalancerServerIdMap<T>::Lookup(
68 const LoadBalancerServerId server_id) const {
69 if (server_id.length() != server_id_len_) {
70 QUIC_BUG(quic_bug_434893339_02)
71 << "Lookup with a " << static_cast<int>(server_id.length())
72 << " byte server ID, map requires " << static_cast<int>(server_id_len_);
73 return std::optional<T>();
74 }
75 auto it = server_id_table_.find(server_id);
76 return (it != server_id_table_.end()) ? it->second : std::optional<const T>();
77 }
78
79 template <typename T>
LookupNoCopy(const LoadBalancerServerId server_id)80 const T* LoadBalancerServerIdMap<T>::LookupNoCopy(
81 const LoadBalancerServerId server_id) const {
82 if (server_id.length() != server_id_len_) {
83 QUIC_BUG(quic_bug_434893339_02)
84 << "Lookup with a " << static_cast<int>(server_id.length())
85 << " byte server ID, map requires " << static_cast<int>(server_id_len_);
86 return nullptr;
87 }
88 auto it = server_id_table_.find(server_id);
89 return (it != server_id_table_.end()) ? &it->second : nullptr;
90 }
91
92 template <typename T>
AddOrReplace(const LoadBalancerServerId server_id,T value)93 void LoadBalancerServerIdMap<T>::AddOrReplace(
94 const LoadBalancerServerId server_id, T value) {
95 if (server_id.length() == server_id_len_) {
96 server_id_table_[server_id] = value;
97 } else {
98 QUIC_BUG(quic_bug_434893339_03)
99 << "Server ID of " << static_cast<int>(server_id.length())
100 << " bytes; this map requires " << static_cast<int>(server_id_len_);
101 }
102 }
103
104 } // namespace quic
105
106 #endif // QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
107