xref: /aosp_15_r20/external/pigweed/pw_rpc_transport/public/pw_rpc_transport/socket_rpc_transport.h (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
#include <signal.h>

#include <array>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <mutex>
#include <string>
#include <string_view>
#include <utility>

#include "pw_assert/assert.h"
#include "pw_bytes/span.h"
#include "pw_chrono/system_clock.h"
#include "pw_rpc_transport/rpc_transport.h"
#include "pw_status/status.h"
#include "pw_status/try.h"
#include "pw_stream/socket_stream.h"
#include "pw_sync/condition_variable.h"
#include "pw_sync/lock_annotations.h"
#include "pw_sync/mutex.h"
#include "pw_sync/thread_notification.h"
#include "pw_thread/sleep.h"
#include "pw_thread/thread_core.h"
34 
35 namespace pw::rpc {
36 
37 namespace internal {
38 
39 void LogSocketListenError(Status);
40 void LogSocketAcceptError(Status);
41 void LogSocketConnectError(Status);
42 void LogSocketReadError(Status);
43 void LogSocketIngressHandlerError(Status);
44 
45 }  // namespace internal
46 
47 template <size_t kReadBufferSize>
48 class SocketRpcTransport : public RpcFrameSender, public thread::ThreadCore {
49  public:
50   struct AsServer {};
51   struct AsClient {};
52 
53   static constexpr AsServer kAsServer{};
54   static constexpr AsClient kAsClient{};
55 
SocketRpcTransport(AsServer,uint16_t port)56   SocketRpcTransport(AsServer, uint16_t port)
57       : role_(ClientServerRole::kServer), port_(port) {}
58 
SocketRpcTransport(AsServer,uint16_t port,RpcIngressHandler & ingress)59   SocketRpcTransport(AsServer, uint16_t port, RpcIngressHandler& ingress)
60       : role_(ClientServerRole::kServer), port_(port), ingress_(&ingress) {}
61 
SocketRpcTransport(AsClient,std::string_view host,uint16_t port)62   SocketRpcTransport(AsClient, std::string_view host, uint16_t port)
63       : role_(ClientServerRole::kClient), host_(host), port_(port) {}
64 
SocketRpcTransport(AsClient,std::string_view host,uint16_t port,RpcIngressHandler & ingress)65   SocketRpcTransport(AsClient,
66                      std::string_view host,
67                      uint16_t port,
68                      RpcIngressHandler& ingress)
69       : role_(ClientServerRole::kClient),
70         host_(host),
71         port_(port),
72         ingress_(&ingress) {}
73 
MaximumTransmissionUnit()74   size_t MaximumTransmissionUnit() const override { return kReadBufferSize; }
port()75   size_t port() const { return port_; }
set_ingress(RpcIngressHandler & ingress)76   void set_ingress(RpcIngressHandler& ingress) { ingress_ = &ingress; }
77 
Send(RpcFrame frame)78   Status Send(RpcFrame frame) override {
79     std::lock_guard lock(write_mutex_);
80     PW_TRY(socket_stream_.Write(frame.header));
81     PW_TRY(socket_stream_.Write(frame.payload));
82     return OkStatus();
83   }
84 
85   // Returns once the transport is connected to its peer.
WaitUntilConnected()86   void WaitUntilConnected() {
87     std::unique_lock lock(connected_mutex_);
88     connected_cv_.wait(lock, [this]() { return connected_; });
89   }
90 
91   // Returns once the transport is ready to be used (e.g. the server is
92   // listening on the port or the client is ready to connect).
WaitUntilReady()93   void WaitUntilReady() {
94     std::unique_lock lock(ready_mutex_);
95     ready_cv_.wait(lock, [this]() { return ready_; });
96   }
97 
Start()98   void Start() {
99     while (!stopped_) {
100       const auto connect_status = EstablishConnection();
101       if (!connect_status.ok()) {
102         this_thread::sleep_for(kConnectionRetryPeriod);
103         continue;
104       }
105       NotifyConnected();
106 
107       while (!stopped_) {
108         const auto read_status = ReadData();
109         // Break if ReadData was cancelled after the transport was stopped.
110         if (stopped_) {
111           break;
112         }
113         if (!read_status.ok()) {
114           internal::LogSocketReadError(read_status);
115         }
116         if (read_status.IsOutOfRange()) {
117           // Need to reconnect (we don't close the stream here because it's
118           // already done in SocketStream::DoRead).
119           {
120             std::lock_guard lock(connected_mutex_);
121             connected_ = false;
122           }
123           break;
124         }
125       }
126     }
127   }
128 
Stop()129   void Stop() {
130     stopped_ = true;
131     socket_stream_.Close();
132     server_socket_.Close();
133   }
134 
135  private:
136   enum class ClientServerRole { kClient, kServer };
137   static constexpr chrono::SystemClock::duration kConnectionRetryPeriod =
138       std::chrono::milliseconds(100);
139 
Run()140   void Run() override { Start(); }
141 
142   // Establishes or accepts a new socket connection. Returns when socket_stream_
143   // contains a valid socket connection, or when the transport is stopped.
EstablishConnection()144   Status EstablishConnection() {
145     if (role_ == ClientServerRole::kServer) {
146       return Serve();
147     }
148     return Connect();
149   }
150 
Serve()151   Status Serve() {
152     PW_DASSERT(role_ == ClientServerRole::kServer);
153 
154     if (!listening_) {
155       const auto listen_status = server_socket_.Listen(port_);
156       if (!listen_status.ok()) {
157         internal::LogSocketListenError(listen_status);
158         return listen_status;
159       }
160     }
161 
162     listening_ = true;
163     port_ = server_socket_.port();
164     NotifyReady();
165 
166     Result<stream::SocketStream> stream = server_socket_.Accept();
167     // If Accept was cancelled due to stopping the transport, return without
168     // error.
169     if (stopped_) {
170       return OkStatus();
171     }
172     if (!stream.ok()) {
173       internal::LogSocketAcceptError(stream.status());
174       return stream.status();
175     }
176     // Ensure that the writer is done writing before updating the stream.
177     std::lock_guard lock(write_mutex_);
178     socket_stream_ = std::move(*stream);
179     return OkStatus();
180   }
181 
Connect()182   Status Connect() {
183     PW_DASSERT(role_ == ClientServerRole::kClient);
184     NotifyReady();
185 
186     std::lock_guard lock(write_mutex_);
187     auto connect_status = socket_stream_.Connect(host_.c_str(), port_);
188     if (!connect_status.ok()) {
189       internal::LogSocketConnectError(connect_status);
190     }
191     return connect_status;
192   }
193 
ReadData()194   Status ReadData() {
195     PW_DASSERT(ingress_ != nullptr);
196     PW_TRY_ASSIGN(auto buffer, socket_stream_.Read(read_buffer_));
197     const auto ingress_status = ingress_->ProcessIncomingData(buffer);
198     if (!ingress_status.ok()) {
199       internal::LogSocketIngressHandlerError(ingress_status);
200     }
201     // ReadData only returns socket stream read errors; ingress errors are only
202     // logged.
203     return OkStatus();
204   }
205 
NotifyConnected()206   void NotifyConnected() {
207     {
208       std::lock_guard lock(connected_mutex_);
209       connected_ = true;
210     }
211     connected_cv_.notify_all();
212   }
213 
NotifyReady()214   void NotifyReady() {
215     {
216       std::lock_guard lock(ready_mutex_);
217       ready_ = true;
218     }
219     ready_cv_.notify_all();
220   }
221 
222   ClientServerRole role_;
223   const std::string host_;
224   std::atomic<uint16_t> port_;
225   RpcIngressHandler* ingress_ = nullptr;
226 
227   // write_mutex_ must be held by the thread performing socket writes.
228   sync::Mutex write_mutex_;
229   stream::SocketStream socket_stream_;
230   stream::ServerSocket server_socket_;
231 
232   sync::Mutex ready_mutex_;
233   sync::ConditionVariable ready_cv_;
234   bool ready_ = false;
235 
236   sync::Mutex connected_mutex_;
237   sync::ConditionVariable connected_cv_;
238   bool connected_ = false;
239 
240   std::atomic<bool> stopped_ = false;
241   bool listening_ = false;
242   std::array<std::byte, kReadBufferSize> read_buffer_{};
243 };
244 
245 }  // namespace pw::rpc
246