xref: /aosp_15_r20/external/grpc-grpc/test/core/end2end/fixtures/h2_oauth2_common.h (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 // Copyright 2023 gRPC authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef GRPC_TEST_CORE_END2END_FIXTURES_H2_OAUTH2_COMMON_H
16 #define GRPC_TEST_CORE_END2END_FIXTURES_H2_OAUTH2_COMMON_H
17 
18 #include <string.h>
19 
20 #include <grpc/grpc.h>
21 #include <grpc/grpc_security.h>
22 #include <grpc/grpc_security_constants.h>
23 #include <grpc/impl/channel_arg_names.h>
24 #include <grpc/slice.h>
25 #include <grpc/status.h>
26 #include <grpc/support/log.h>
27 
28 #include "src/core/lib/channel/channel_args.h"
29 #include "src/core/lib/iomgr/error.h"
30 #include "src/core/lib/security/credentials/credentials.h"
31 #include "src/core/lib/security/credentials/ssl/ssl_credentials.h"
32 #include "test/core/end2end/end2end_tests.h"
33 #include "test/core/end2end/fixtures/secure_fixture.h"
34 #include "test/core/util/tls_utils.h"
35 
36 class Oauth2Fixture : public SecureFixture {
37  public:
Oauth2Fixture(grpc_tls_version tls_version)38   explicit Oauth2Fixture(grpc_tls_version tls_version)
39       : tls_version_(tls_version) {}
40 
CaCertPath()41   static const char* CaCertPath() { return "src/core/tsi/test_creds/ca.pem"; }
ServerCertPath()42   static const char* ServerCertPath() {
43     return "src/core/tsi/test_creds/server1.pem";
44   }
ServerKeyPath()45   static const char* ServerKeyPath() {
46     return "src/core/tsi/test_creds/server1.key";
47   }
48 
49  private:
50   struct TestProcessorState {};
51 
oauth2_md()52   static const char* oauth2_md() { return "Bearer aaslkfjs424535asdf"; }
client_identity_property_name()53   static const char* client_identity_property_name() { return "smurf_name"; }
client_identity()54   static const char* client_identity() { return "Brainy Smurf"; }
55 
find_metadata(const grpc_metadata * md,size_t md_count,const char * key,const char * value)56   static const grpc_metadata* find_metadata(const grpc_metadata* md,
57                                             size_t md_count, const char* key,
58                                             const char* value) {
59     size_t i;
60     for (i = 0; i < md_count; i++) {
61       if (grpc_slice_str_cmp(md[i].key, key) == 0 &&
62           grpc_slice_str_cmp(md[i].value, value) == 0) {
63         return &md[i];
64       }
65     }
66     return nullptr;
67   }
68 
process_oauth2_success(void *,grpc_auth_context *,const grpc_metadata * md,size_t md_count,grpc_process_auth_metadata_done_cb cb,void * user_data)69   static void process_oauth2_success(void*, grpc_auth_context*,
70                                      const grpc_metadata* md, size_t md_count,
71                                      grpc_process_auth_metadata_done_cb cb,
72                                      void* user_data) {
73     const grpc_metadata* oauth2 =
74         find_metadata(md, md_count, "authorization", oauth2_md());
75     GPR_ASSERT(oauth2 != nullptr);
76     cb(user_data, oauth2, 1, nullptr, 0, GRPC_STATUS_OK, nullptr);
77   }
78 
process_oauth2_failure(void * state,grpc_auth_context *,const grpc_metadata * md,size_t md_count,grpc_process_auth_metadata_done_cb cb,void * user_data)79   static void process_oauth2_failure(void* state, grpc_auth_context* /*ctx*/,
80                                      const grpc_metadata* md, size_t md_count,
81                                      grpc_process_auth_metadata_done_cb cb,
82                                      void* user_data) {
83     const grpc_metadata* oauth2 =
84         find_metadata(md, md_count, "authorization", oauth2_md());
85     GPR_ASSERT(state != nullptr);
86     GPR_ASSERT(oauth2 != nullptr);
87     cb(user_data, oauth2, 1, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr);
88   }
89 
test_processor_create(bool failing)90   static grpc_auth_metadata_processor test_processor_create(bool failing) {
91     auto* s = new TestProcessorState;
92     grpc_auth_metadata_processor result;
93     result.state = s;
94     result.destroy = [](void* p) {
95       delete static_cast<TestProcessorState*>(p);
96     };
97     if (failing) {
98       result.process = process_oauth2_failure;
99     } else {
100       result.process = process_oauth2_success;
101     }
102     return result;
103   }
104 
MutateClientArgs(grpc_core::ChannelArgs args)105   grpc_core::ChannelArgs MutateClientArgs(
106       grpc_core::ChannelArgs args) override {
107     return args.Set(GRPC_SSL_TARGET_NAME_OVERRIDE_ARG, "foo.test.google.fr");
108   }
109 
MakeClientCreds(const grpc_core::ChannelArgs &)110   grpc_channel_credentials* MakeClientCreds(
111       const grpc_core::ChannelArgs&) override {
112     std::string test_root_cert =
113         grpc_core::testing::GetFileContents(CaCertPath());
114     grpc_channel_credentials* ssl_creds = grpc_ssl_credentials_create(
115         test_root_cert.c_str(), nullptr, nullptr, nullptr);
116     if (ssl_creds != nullptr) {
117       // Set the min and max TLS version.
118       grpc_ssl_credentials* creds =
119           reinterpret_cast<grpc_ssl_credentials*>(ssl_creds);
120       creds->set_min_tls_version(tls_version_);
121       creds->set_max_tls_version(tls_version_);
122     }
123     grpc_call_credentials* oauth2_creds =
124         grpc_md_only_test_credentials_create("authorization", oauth2_md());
125     grpc_channel_credentials* ssl_oauth2_creds =
126         grpc_composite_channel_credentials_create(ssl_creds, oauth2_creds,
127                                                   nullptr);
128     grpc_channel_credentials_release(ssl_creds);
129     grpc_call_credentials_release(oauth2_creds);
130     return ssl_oauth2_creds;
131   }
132 
MakeServerCreds(const grpc_core::ChannelArgs & args)133   grpc_server_credentials* MakeServerCreds(
134       const grpc_core::ChannelArgs& args) override {
135     std::string server_cert =
136         grpc_core::testing::GetFileContents(ServerCertPath());
137     std::string server_key =
138         grpc_core::testing::GetFileContents(ServerKeyPath());
139     grpc_ssl_pem_key_cert_pair pem_key_cert_pair = {server_key.c_str(),
140                                                     server_cert.c_str()};
141     grpc_server_credentials* ssl_creds = grpc_ssl_server_credentials_create(
142         nullptr, &pem_key_cert_pair, 1, 0, nullptr);
143     if (ssl_creds != nullptr) {
144       // Set the min and max TLS version.
145       grpc_ssl_server_credentials* creds =
146           reinterpret_cast<grpc_ssl_server_credentials*>(ssl_creds);
147       creds->set_min_tls_version(tls_version_);
148       creds->set_max_tls_version(tls_version_);
149     }
150     grpc_server_credentials_set_auth_metadata_processor(
151         ssl_creds,
152         test_processor_create(args.Contains(FAIL_AUTH_CHECK_SERVER_ARG_NAME)));
153     return ssl_creds;
154   }
155 
process_auth_failure(void * state,grpc_auth_context *,const grpc_metadata *,size_t,grpc_process_auth_metadata_done_cb cb,void * user_data)156   static void process_auth_failure(void* state, grpc_auth_context* /*ctx*/,
157                                    const grpc_metadata* /*md*/,
158                                    size_t /*md_count*/,
159                                    grpc_process_auth_metadata_done_cb cb,
160                                    void* user_data) {
161     GPR_ASSERT(state == nullptr);
162     cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_UNAUTHENTICATED, nullptr);
163   }
164 
165   grpc_tls_version tls_version_;
166 };
167 
168 #endif  // GRPC_TEST_CORE_END2END_FIXTURES_H2_OAUTH2_COMMON_H
169