xref: /aosp_15_r20/external/grpc-grpc/test/cpp/util/cli_credentials.cc (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1 //
2 //
3 // Copyright 2016 gRPC authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 //
17 //
18 
19 #include "test/cpp/util/cli_credentials.h"
20 
21 #include "absl/flags/flag.h"
22 
23 #include <grpc/slice.h>
24 #include <grpc/support/log.h>
25 #include <grpcpp/support/slice.h>
26 
27 #include "src/core/lib/gprpp/crash.h"
28 #include "src/core/lib/gprpp/load_file.h"
29 
30 ABSL_RETIRED_FLAG(bool, enable_ssl, false,
31                   "Replaced by --channel_creds_type=ssl.");
32 ABSL_RETIRED_FLAG(bool, use_auth, false,
33                   "Replaced by --channel_creds_type=gdc.");
34 ABSL_RETIRED_FLAG(std::string, access_token, "",
35                   "Replaced by --call_creds=access_token=<token>.");
36 ABSL_FLAG(
37     std::string, ssl_target, "",
38     "If not empty, treat the server host name as this for ssl/tls certificate "
39     "validation.");
40 ABSL_FLAG(
41     std::string, ssl_client_cert, "",
42     "If not empty, load this PEM formatted client certificate file. Requires "
43     "use of --ssl_client_key.");
44 ABSL_FLAG(std::string, ssl_client_key, "",
45           "If not empty, load this PEM formatted private key. Requires use of "
46           "--ssl_client_cert");
47 ABSL_FLAG(
48     std::string, local_connect_type, "local_tcp",
49     "The type of local connections for which local channel credentials will "
50     "be applied. Should be local_tcp or uds.");
51 ABSL_FLAG(
52     std::string, channel_creds_type, "",
53     "The channel creds type: insecure, ssl, gdc (Google Default Credentials), "
54     "alts, or local.");
55 ABSL_FLAG(
56     std::string, call_creds, "",
57     "Call credentials to use: none (default), or access_token=<token>. If "
58     "provided, the call creds are composited on top of channel creds.");
59 
60 namespace grpc {
61 namespace testing {
62 
namespace {

// Marker that identifies a --call_creds value carrying an OAuth2 access token.
constexpr char kAccessTokenPrefix[] = "access_token=";
// Number of characters in kAccessTokenPrefix, excluding the trailing NUL.
constexpr std::string::size_type kAccessTokenPrefixLen =
    sizeof(kAccessTokenPrefix) - 1;

// True when `auth` has the form "access_token=<token>" with a non-empty
// <token>; a bare "access_token=" is rejected.
bool IsAccessToken(const std::string& auth) {
  if (auth.size() <= kAccessTokenPrefixLen) {
    return false;
  }
  return auth.compare(0, kAccessTokenPrefixLen, kAccessTokenPrefix) == 0;
}

// Extracts <token> from "access_token=<token>"; yields an empty string for
// anything IsAccessToken() rejects.
std::string AccessToken(const std::string& auth) {
  return IsAccessToken(auth) ? auth.substr(kAccessTokenPrefixLen)
                             : std::string();
}

}  // namespace
82 
GetDefaultChannelCredsType() const83 std::string CliCredentials::GetDefaultChannelCredsType() const {
84   return "insecure";
85 }
86 
GetDefaultCallCreds() const87 std::string CliCredentials::GetDefaultCallCreds() const { return "none"; }
88 
89 std::shared_ptr<grpc::ChannelCredentials>
GetChannelCredentials() const90 CliCredentials::GetChannelCredentials() const {
91   if (absl::GetFlag(FLAGS_channel_creds_type) == "insecure") {
92     return grpc::InsecureChannelCredentials();
93   } else if (absl::GetFlag(FLAGS_channel_creds_type) == "ssl") {
94     grpc::SslCredentialsOptions ssl_creds_options;
95     // TODO(@Capstan): This won't affect Google Default Credentials using SSL.
96     if (!absl::GetFlag(FLAGS_ssl_client_cert).empty()) {
97       auto cert = grpc_core::LoadFile(absl::GetFlag(FLAGS_ssl_client_cert),
98                                       /*add_null_terminator=*/false);
99       if (!cert.ok()) {
100         gpr_log(GPR_ERROR, "error loading file %s: %s",
101                 absl::GetFlag(FLAGS_ssl_client_cert).c_str(),
102                 cert.status().ToString().c_str());
103       } else {
104         ssl_creds_options.pem_cert_chain = std::string(cert->as_string_view());
105       }
106     }
107     if (!absl::GetFlag(FLAGS_ssl_client_key).empty()) {
108       auto key = grpc_core::LoadFile(absl::GetFlag(FLAGS_ssl_client_key),
109                                      /*add_null_terminator=*/false);
110       if (!key.ok()) {
111         gpr_log(GPR_ERROR, "error loading file %s: %s",
112                 absl::GetFlag(FLAGS_ssl_client_key).c_str(),
113                 key.status().ToString().c_str());
114       } else {
115         ssl_creds_options.pem_private_key = std::string(key->as_string_view());
116       }
117     }
118     return grpc::SslCredentials(ssl_creds_options);
119   } else if (absl::GetFlag(FLAGS_channel_creds_type) == "gdc") {
120     return grpc::GoogleDefaultCredentials();
121   } else if (absl::GetFlag(FLAGS_channel_creds_type) == "alts") {
122     return grpc::experimental::AltsCredentials(
123         grpc::experimental::AltsCredentialsOptions());
124   } else if (absl::GetFlag(FLAGS_channel_creds_type) == "local") {
125     if (absl::GetFlag(FLAGS_local_connect_type) == "local_tcp") {
126       return grpc::experimental::LocalCredentials(LOCAL_TCP);
127     } else if (absl::GetFlag(FLAGS_local_connect_type) == "uds") {
128       return grpc::experimental::LocalCredentials(UDS);
129     } else {
130       fprintf(stderr,
131               "--local_connect_type=%s invalid; must be local_tcp or uds.\n",
132               absl::GetFlag(FLAGS_local_connect_type).c_str());
133     }
134   }
135   fprintf(stderr,
136           "--channel_creds_type=%s invalid; must be insecure, ssl, gdc, "
137           "alts, or local.\n",
138           absl::GetFlag(FLAGS_channel_creds_type).c_str());
139   return std::shared_ptr<grpc::ChannelCredentials>();
140 }
141 
GetCallCredentials() const142 std::shared_ptr<grpc::CallCredentials> CliCredentials::GetCallCredentials()
143     const {
144   if (IsAccessToken(absl::GetFlag(FLAGS_call_creds))) {
145     return grpc::AccessTokenCredentials(
146         AccessToken(absl::GetFlag(FLAGS_call_creds)));
147   }
148   if (absl::GetFlag(FLAGS_call_creds) == "none") {
149     // Nothing to do; creds, if any, are baked into the channel.
150     return std::shared_ptr<grpc::CallCredentials>();
151   }
152   fprintf(stderr,
153           "--call_creds=%s invalid; must be none "
154           "or access_token=<token>.\n",
155           absl::GetFlag(FLAGS_call_creds).c_str());
156   return std::shared_ptr<grpc::CallCredentials>();
157 }
158 
GetCredentials() const159 std::shared_ptr<grpc::ChannelCredentials> CliCredentials::GetCredentials()
160     const {
161   if (absl::GetFlag(FLAGS_call_creds).empty()) {
162     absl::SetFlag(&FLAGS_call_creds, GetDefaultCallCreds());
163   }
164   if (absl::GetFlag(FLAGS_channel_creds_type).empty()) {
165     absl::SetFlag(&FLAGS_channel_creds_type, GetDefaultChannelCredsType());
166   }
167   std::shared_ptr<grpc::ChannelCredentials> channel_creds =
168       GetChannelCredentials();
169   // Composite any call-type credentials on top of the base channel.
170   std::shared_ptr<grpc::CallCredentials> call_creds = GetCallCredentials();
171   return (channel_creds == nullptr || call_creds == nullptr)
172              ? channel_creds
173              : grpc::CompositeChannelCredentials(channel_creds, call_creds);
174 }
175 
GetCredentialUsage() const176 std::string CliCredentials::GetCredentialUsage() const {
177   return "    --ssl_target             ; Set server host for ssl validation\n"
178          "    --ssl_client_cert        ; Client cert for ssl\n"
179          "    --ssl_client_key         ; Client private key for ssl\n"
180          "    --local_connect_type     ; Set to local_tcp or uds\n"
181          "    --channel_creds_type     ; Set to insecure, ssl, gdc, alts, or "
182          "local\n"
183          "    --call_creds             ; Set to none, or"
184          " access_token=<token>\n";
185 }
186 
GetSslTargetNameOverride() const187 std::string CliCredentials::GetSslTargetNameOverride() const {
188   bool use_ssl = absl::GetFlag(FLAGS_channel_creds_type) == "ssl" ||
189                  absl::GetFlag(FLAGS_channel_creds_type) == "gdc";
190   return use_ssl ? absl::GetFlag(FLAGS_ssl_target) : "";
191 }
192 
193 }  // namespace testing
194 }  // namespace grpc
195