1 // Copyright (c) 2022 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
6 #define QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
7 
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "quiche/quic/load_balancer/load_balancer_server_id.h"
#include "quiche/quic/platform/api/quic_bug_tracker.h"
15 
16 namespace quic {
17 
// This class wraps an absl::flat_hash_map that associates server IDs with an
// arbitrary type T. Every server ID in the map must have the same fixed length,
// chosen at creation time; Lookup, LookupNoCopy, and AddOrReplace reject IDs of
// any other length with a QUIC_BUG. A load balancer might use this to connect a
// server ID with a pool member data structure.
22 template <typename T>
23 class QUIC_EXPORT_PRIVATE LoadBalancerServerIdMap {
24  public:
25   // Returns a newly created pool for server IDs of length |server_id_len|, or
26   // nullptr if |server_id_len| is invalid.
27   static std::shared_ptr<LoadBalancerServerIdMap> Create(uint8_t server_id_len);
28 
29   // Returns the entry associated with |server_id|, if present. For small |T|,
30   // use Lookup. For large |T|, use LookupNoCopy.
31   std::optional<const T> Lookup(LoadBalancerServerId server_id) const;
32   const T* LookupNoCopy(LoadBalancerServerId server_id) const;
33 
34   // Updates the table so that |value| is associated with |server_id|. Sets
35   // QUIC_BUG if the length is incorrect for this map.
36   void AddOrReplace(LoadBalancerServerId server_id, T value);
37 
38   // Removes the entry associated with |server_id|.
Erase(const LoadBalancerServerId server_id)39   void Erase(const LoadBalancerServerId server_id) {
40     server_id_table_.erase(server_id);
41   }
42 
server_id_len()43   uint8_t server_id_len() const { return server_id_len_; }
44 
45  private:
LoadBalancerServerIdMap(uint8_t server_id_len)46   LoadBalancerServerIdMap(uint8_t server_id_len)
47       : server_id_len_(server_id_len) {}
48 
49   const uint8_t server_id_len_;  // All server IDs must be of this length.
50   absl::flat_hash_map<LoadBalancerServerId, T> server_id_table_;
51 };
52 
53 template <typename T>
Create(const uint8_t server_id_len)54 std::shared_ptr<LoadBalancerServerIdMap<T>> LoadBalancerServerIdMap<T>::Create(
55     const uint8_t server_id_len) {
56   if (server_id_len == 0 || server_id_len > kLoadBalancerMaxServerIdLen) {
57     QUIC_BUG(quic_bug_434893339_01)
58         << "Tried to configure map with server ID length "
59         << static_cast<int>(server_id_len);
60     return nullptr;
61   }
62   return std::make_shared<LoadBalancerServerIdMap<T>>(
63       LoadBalancerServerIdMap(server_id_len));
64 }
65 
66 template <typename T>
Lookup(const LoadBalancerServerId server_id)67 std::optional<const T> LoadBalancerServerIdMap<T>::Lookup(
68     const LoadBalancerServerId server_id) const {
69   if (server_id.length() != server_id_len_) {
70     QUIC_BUG(quic_bug_434893339_02)
71         << "Lookup with a " << static_cast<int>(server_id.length())
72         << " byte server ID, map requires " << static_cast<int>(server_id_len_);
73     return std::optional<T>();
74   }
75   auto it = server_id_table_.find(server_id);
76   return (it != server_id_table_.end()) ? it->second : std::optional<const T>();
77 }
78 
79 template <typename T>
LookupNoCopy(const LoadBalancerServerId server_id)80 const T* LoadBalancerServerIdMap<T>::LookupNoCopy(
81     const LoadBalancerServerId server_id) const {
82   if (server_id.length() != server_id_len_) {
83     QUIC_BUG(quic_bug_434893339_02)
84         << "Lookup with a " << static_cast<int>(server_id.length())
85         << " byte server ID, map requires " << static_cast<int>(server_id_len_);
86     return nullptr;
87   }
88   auto it = server_id_table_.find(server_id);
89   return (it != server_id_table_.end()) ? &it->second : nullptr;
90 }
91 
92 template <typename T>
AddOrReplace(const LoadBalancerServerId server_id,T value)93 void LoadBalancerServerIdMap<T>::AddOrReplace(
94     const LoadBalancerServerId server_id, T value) {
95   if (server_id.length() == server_id_len_) {
96     server_id_table_[server_id] = value;
97   } else {
98     QUIC_BUG(quic_bug_434893339_03)
99         << "Server ID of " << static_cast<int>(server_id.length())
100         << " bytes; this map requires " << static_cast<int>(server_id_len_);
101   }
102 }
103 
104 }  // namespace quic
105 
106 #endif  // QUICHE_QUIC_LOAD_BALANCER_LOAD_BALANCER_SERVER_ID_MAP_H_
107