xref: /aosp_15_r20/external/tink/cc/aead/internal/zero_copy_aead_wrapper.cc (revision e7b1675dde1b92d52ec075b0a92829627f2c52a5)
1 // Copyright 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 ///////////////////////////////////////////////////////////////////////////////
16 
17 #include "tink/aead/internal/zero_copy_aead_wrapper.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/strings/string_view.h"
25 #include "tink/crypto_format.h"
26 #include "tink/subtle/subtle_util.h"
27 
28 namespace crypto {
29 namespace tink {
30 namespace internal {
31 
32 namespace {
33 
34 typedef crypto::tink::PrimitiveSet<ZeroCopyAead>::Entry<ZeroCopyAead>
35     ZeroCopyAeadEntry;
36 
Validate(PrimitiveSet<ZeroCopyAead> * aead_set)37 util::Status Validate(PrimitiveSet<ZeroCopyAead>* aead_set) {
38   if (aead_set == nullptr) {
39     return util::Status(absl::StatusCode::kInternal,
40                         "aead_set must be non-NULL");
41   }
42   if (aead_set->get_primary() == nullptr) {
43     return util::Status(absl::StatusCode::kInvalidArgument,
44                         "aead_set has no primary");
45   }
46   return util::OkStatus();
47 }
48 
49 class ZeroCopyAeadSetWrapper : public Aead {
50  public:
ZeroCopyAeadSetWrapper(std::unique_ptr<PrimitiveSet<ZeroCopyAead>> aead_set)51   explicit ZeroCopyAeadSetWrapper(
52       std::unique_ptr<PrimitiveSet<ZeroCopyAead>> aead_set)
53       : aead_set_(std::move(aead_set)) {}
54 
55   util::StatusOr<std::string> Encrypt(
56       absl::string_view plaintext,
57       absl::string_view associated_data) const override;
58 
59   util::StatusOr<std::string> Decrypt(
60       absl::string_view ciphertext,
61       absl::string_view associated_data) const override;
62 
63   ~ZeroCopyAeadSetWrapper() override = default;
64 
65  private:
66   std::unique_ptr<PrimitiveSet<ZeroCopyAead>> aead_set_;
67 };
68 
Encrypt(absl::string_view plaintext,absl::string_view associated_data) const69 util::StatusOr<std::string> ZeroCopyAeadSetWrapper::Encrypt(
70     absl::string_view plaintext, absl::string_view associated_data) const {
71   std::string ciphertext = aead_set_->get_primary()->get_identifier();
72   int64_t key_id_size = ciphertext.size();
73   ZeroCopyAead& aead = aead_set_->get_primary()->get_primitive();
74   subtle::ResizeStringUninitialized(
75       &ciphertext, key_id_size + aead.MaxEncryptionSize(plaintext.size()));
76 
77   // Write ciphertext at position ciphertext + CryptoFormat::kNonRawPrefixSize.
78   util::StatusOr<int64_t> ciphertext_size =
79       aead.Encrypt(plaintext, associated_data,
80                    absl::MakeSpan(ciphertext).subspan(key_id_size));
81   if (!ciphertext_size.ok()) return ciphertext_size.status();
82   ciphertext.resize(key_id_size + *ciphertext_size);
83 
84   return ciphertext;
85 }
86 
Decrypt(absl::string_view ciphertext,absl::string_view associated_data) const87 util::StatusOr<std::string> ZeroCopyAeadSetWrapper::Decrypt(
88     absl::string_view ciphertext, absl::string_view associated_data) const {
89   if (ciphertext.size() > CryptoFormat::kNonRawPrefixSize) {
90     std::string key_id =
91         std::string(ciphertext.substr(0, CryptoFormat::kNonRawPrefixSize));
92     util::StatusOr<const std::vector<std::unique_ptr<ZeroCopyAeadEntry>>*>
93         primitives = aead_set_->get_primitives(key_id);
94 
95     if (primitives.ok() && *primitives != nullptr) {
96       absl::string_view raw_ciphertext =
97           ciphertext.substr(key_id.size(), ciphertext.size());
98 
99       for (const std::unique_ptr<ZeroCopyAeadEntry>& entry : **primitives) {
100         ZeroCopyAead& aead = entry->get_primitive();
101         std::string plaintext;
102         subtle::ResizeStringUninitialized(
103             &plaintext, aead.MaxDecryptionSize(raw_ciphertext.size()));
104         util::StatusOr<int64_t> plaintext_size = entry->get_primitive().Decrypt(
105             raw_ciphertext, associated_data, absl::MakeSpan(plaintext));
106         if (plaintext_size.ok()) {
107           plaintext.resize(*plaintext_size);
108           return plaintext;
109         }
110       }
111     }
112   }
113 
114   // Try raw keys because matching keys failed to decrypt.
115   util::StatusOr<const std::vector<std::unique_ptr<ZeroCopyAeadEntry>>*>
116       raw_primitives = aead_set_->get_raw_primitives();
117   if (raw_primitives.ok() && *raw_primitives != nullptr) {
118     for (const std::unique_ptr<ZeroCopyAeadEntry>& entry : **raw_primitives) {
119       ZeroCopyAead& aead = entry->get_primitive();
120       std::string plaintext;
121       subtle::ResizeStringUninitialized(
122           &plaintext, aead.MaxDecryptionSize(ciphertext.size()));
123       util::StatusOr<int64_t> plaintext_size =
124           aead.Decrypt(ciphertext, associated_data, absl::MakeSpan(plaintext));
125       if (plaintext_size.ok()) {
126         plaintext.resize(*plaintext_size);
127         return plaintext;
128       }
129     }
130   }
131 
132   return util::Status(absl::StatusCode::kInvalidArgument, "Decryption failed");
133 }
134 
135 }  // anonymous namespace
136 
Wrap(std::unique_ptr<PrimitiveSet<ZeroCopyAead>> aead_set) const137 util::StatusOr<std::unique_ptr<Aead>> ZeroCopyAeadWrapper::Wrap(
138     std::unique_ptr<PrimitiveSet<ZeroCopyAead>> aead_set) const {
139   util::Status status = Validate(aead_set.get());
140   if (!status.ok()) return status;
141   std::unique_ptr<Aead> aead =
142       absl::make_unique<ZeroCopyAeadSetWrapper>(std::move(aead_set));
143   return std::move(aead);
144 }
145 
146 }  // namespace internal
147 }  // namespace tink
148 }  // namespace crypto
149