xref: /aosp_15_r20/external/tink/cc/aead/aead_wrapper.cc (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1 // Copyright 2017 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16 
17 #include "tink/aead/aead_wrapper.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "absl/memory/memory.h"
24 #include "absl/status/status.h"
25 #include "absl/strings/str_cat.h"
26 #include "absl/strings/string_view.h"
27 #include "tink/aead.h"
28 #include "tink/crypto_format.h"
29 #include "tink/internal/monitoring_util.h"
30 #include "tink/internal/registry_impl.h"
31 #include "tink/internal/util.h"
32 #include "tink/monitoring/monitoring.h"
33 #include "tink/primitive_set.h"
34 #include "tink/util/status.h"
35 #include "tink/util/statusor.h"
36 
37 namespace crypto {
38 namespace tink {
39 namespace {
40 
41 constexpr absl::string_view kPrimitive = "aead";
42 constexpr absl::string_view kEncryptApi = "encrypt";
43 constexpr absl::string_view kDecryptApi = "decrypt";
44 
Validate(PrimitiveSet<Aead> * aead_set)45 util::Status Validate(PrimitiveSet<Aead>* aead_set) {
46   if (aead_set == nullptr) {
47     return util::Status(absl::StatusCode::kInternal,
48                         "aead_set must be non-NULL");
49   }
50   if (aead_set->get_primary() == nullptr) {
51     return util::Status(absl::StatusCode::kInvalidArgument,
52                         "aead_set has no primary");
53   }
54   return util::OkStatus();
55 }
56 
57 // The actual wrapper.
58 class AeadSetWrapper : public Aead {
59  public:
AeadSetWrapper(std::unique_ptr<PrimitiveSet<Aead>> aead_set,std::unique_ptr<MonitoringClient> monitoring_encryption_client=nullptr,std::unique_ptr<MonitoringClient> monitoring_decryption_client=nullptr)60   explicit AeadSetWrapper(
61       std::unique_ptr<PrimitiveSet<Aead>> aead_set,
62       std::unique_ptr<MonitoringClient> monitoring_encryption_client = nullptr,
63       std::unique_ptr<MonitoringClient> monitoring_decryption_client = nullptr)
64       : aead_set_(std::move(aead_set)),
65         monitoring_encryption_client_(std::move(monitoring_encryption_client)),
66         monitoring_decryption_client_(std::move(monitoring_decryption_client)) {
67   }
68 
69   util::StatusOr<std::string> Encrypt(
70       absl::string_view plaintext,
71       absl::string_view associated_data) const override;
72 
73   util::StatusOr<std::string> Decrypt(
74       absl::string_view ciphertext,
75       absl::string_view associated_data) const override;
76 
77  private:
78   std::unique_ptr<PrimitiveSet<Aead>> aead_set_;
79   std::unique_ptr<MonitoringClient> monitoring_encryption_client_;
80   std::unique_ptr<MonitoringClient> monitoring_decryption_client_;
81 };
82 
Encrypt(absl::string_view plaintext,absl::string_view associated_data) const83 util::StatusOr<std::string> AeadSetWrapper::Encrypt(
84     absl::string_view plaintext, absl::string_view associated_data) const {
85   associated_data = internal::EnsureStringNonNull(associated_data);
86   const Aead& primitive = aead_set_->get_primary()->get_primitive();
87   util::StatusOr<std::string> ciphertext =
88       primitive.Encrypt(plaintext, associated_data);
89   if (!ciphertext.ok()) {
90     if (monitoring_encryption_client_ != nullptr) {
91       monitoring_encryption_client_->LogFailure();
92     }
93     return ciphertext.status();
94   }
95   if (monitoring_encryption_client_ != nullptr) {
96     monitoring_encryption_client_->Log(aead_set_->get_primary()->get_key_id(),
97                                        plaintext.size());
98   }
99   const std::string& key_id = aead_set_->get_primary()->get_identifier();
100   return absl::StrCat(key_id, *ciphertext);
101 }
102 
Decrypt(absl::string_view ciphertext,absl::string_view associated_data) const103 util::StatusOr<std::string> AeadSetWrapper::Decrypt(
104     absl::string_view ciphertext, absl::string_view associated_data) const {
105   // BoringSSL expects a non-null pointer for plaintext and associated_data,
106   // regardless of whether the size is 0.
107   associated_data = internal::EnsureStringNonNull(associated_data);
108 
109   if (ciphertext.length() > CryptoFormat::kNonRawPrefixSize) {
110     absl::string_view key_id =
111         ciphertext.substr(0, CryptoFormat::kNonRawPrefixSize);
112     util::StatusOr<const PrimitiveSet<Aead>::Primitives*> primitives =
113         aead_set_->get_primitives(key_id);
114     if (primitives.ok()) {
115       absl::string_view raw_ciphertext =
116           ciphertext.substr(CryptoFormat::kNonRawPrefixSize);
117       for (const std::unique_ptr<PrimitiveSet<Aead>::Entry<Aead>>& aead_entry :
118            **primitives) {
119         Aead& aead = aead_entry->get_primitive();
120         util::StatusOr<std::string> plaintext =
121             aead.Decrypt(raw_ciphertext, associated_data);
122         if (plaintext.ok()) {
123           if (monitoring_decryption_client_ != nullptr) {
124             monitoring_decryption_client_->Log(aead_entry->get_key_id(),
125                                                raw_ciphertext.size());
126           }
127           return plaintext;
128         }
129       }
130     }
131   }
132 
133   // No matching key succeeded with decryption, try all RAW keys.
134   util::StatusOr<const PrimitiveSet<Aead>::Primitives*> raw_primitives =
135       aead_set_->get_raw_primitives();
136   if (raw_primitives.ok()) {
137     for (const std::unique_ptr<PrimitiveSet<Aead>::Entry<Aead>>& aead_entry :
138          **raw_primitives) {
139       Aead& aead = aead_entry->get_primitive();
140       util::StatusOr<std::string> plaintext =
141           aead.Decrypt(ciphertext, associated_data);
142       if (plaintext.ok()) {
143         if (monitoring_decryption_client_ != nullptr) {
144           monitoring_decryption_client_->Log(aead_entry->get_key_id(),
145                                              ciphertext.size());
146         }
147         return plaintext;
148       }
149     }
150   }
151   if (monitoring_decryption_client_ != nullptr) {
152     monitoring_decryption_client_->LogFailure();
153   }
154   return util::Status(absl::StatusCode::kInvalidArgument, "decryption failed");
155 }
156 
157 }  // namespace
158 
Wrap(std::unique_ptr<PrimitiveSet<Aead>> aead_set) const159 util::StatusOr<std::unique_ptr<Aead>> AeadWrapper::Wrap(
160     std::unique_ptr<PrimitiveSet<Aead>> aead_set) const {
161   util::Status status = Validate(aead_set.get());
162   if (!status.ok()) {
163     return status;
164   }
165 
166   MonitoringClientFactory* const monitoring_factory =
167       internal::RegistryImpl::GlobalInstance().GetMonitoringClientFactory();
168 
169   // Monitoring is not enabled. Create a wrapper without monitoring clients.
170   if (monitoring_factory == nullptr) {
171     return {absl::make_unique<AeadSetWrapper>(std::move(aead_set))};
172   }
173 
174   util::StatusOr<MonitoringKeySetInfo> keyset_info =
175       internal::MonitoringKeySetInfoFromPrimitiveSet(*aead_set);
176   if (!keyset_info.ok()) {
177     return keyset_info.status();
178   }
179 
180   util::StatusOr<std::unique_ptr<MonitoringClient>>
181       monitoring_encryption_client = monitoring_factory->New(
182           MonitoringContext(kPrimitive, kEncryptApi, *keyset_info));
183   if (!monitoring_encryption_client.ok()) {
184     return monitoring_encryption_client.status();
185   }
186 
187   util::StatusOr<std::unique_ptr<MonitoringClient>>
188       monitoring_decryption_client = monitoring_factory->New(
189           MonitoringContext(kPrimitive, kDecryptApi, *keyset_info));
190   if (!monitoring_decryption_client.ok()) {
191     return monitoring_decryption_client.status();
192   }
193 
194   return {absl::make_unique<AeadSetWrapper>(
195       std::move(aead_set), *std::move(monitoring_encryption_client),
196       *std::move(monitoring_decryption_client))};
197 }
198 
199 }  // namespace tink
200 }  // namespace crypto
201