xref: /aosp_15_r20/external/pigweed/pw_rpc_transport/public/pw_rpc_transport/simple_framing.h (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2023 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 #pragma once
15 
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "pw_assert/assert.h"
#include "pw_bytes/span.h"
#include "pw_rpc_transport/rpc_transport.h"
#include "pw_status/status.h"
#include "pw_status/try.h"
21 
22 namespace pw::rpc {
23 
// The following encoder and decoder implement a very simple RPC framing
// protocol. The first frame of a packet carries a header containing the total
// packet size, followed by up to (max frame size - header size) bytes of the
// packet in its payload. Subsequent frames of the same packet have an empty
// header and carry the rest of the packet in their payload.
//
// The first frame's header also contains a special marker, used to attempt to
// resynchronize the receiver if some frames were not sent. We expect all
// transports using this framing type to be reliable. Even so, a transport
// write timeout could still result in only the first few frames being sent
// and the others being dropped. In that case the decoder makes a best-effort
// recovery by skipping input until it sees something that resembles a valid
// header.
//
// Neither the encoder nor the decoder is thread-safe. The caller must ensure
// their correct use in a multi-threaded environment.
39 
namespace internal {

// Reports that a received frame header announced a packet of `packet_size`
// bytes, which exceeds the decoder's `max_packet_size`. Defined out of line.
auto LogReceivedRpcPacketTooLarge(size_t packet_size, size_t max_packet_size)
    -> void;

// Reports that the decoder expected a frame header but found bytes without
// the frame marker. Defined out of line.
auto LogMalformedRpcFrameHeader() -> void;

}  // namespace internal
46 
47 template <size_t kMaxPacketSize>
48 class SimpleRpcPacketEncoder
49     : public RpcPacketEncoder<SimpleRpcPacketEncoder<kMaxPacketSize>> {
50   static_assert(kMaxPacketSize <= 1 << 16);
51 
52  public:
53   static constexpr size_t kHeaderSize = 4;
54   static constexpr uint16_t kFrameMarker = 0x27f1;
55 
56   // Encodes `packet` with a simple framing protocol and split the resulting
57   // frame into chunks of `RpcFrame`s where every `RpcFrame` is no longer than
58   // `max_frame_size`. Calls `callback` for for each of the resulting
59   // `RpcFrame`s.
Encode(ConstByteSpan rpc_packet,size_t max_frame_size,OnRpcFrameEncodedCallback && callback)60   Status Encode(ConstByteSpan rpc_packet,
61                 size_t max_frame_size,
62                 OnRpcFrameEncodedCallback&& callback) {
63     if (rpc_packet.size() > kMaxPacketSize) {
64       return Status::FailedPrecondition();
65     }
66     if (max_frame_size <= kHeaderSize) {
67       return Status::FailedPrecondition();
68     }
69 
70     // First frame. This is the only frame that contains a header.
71     const auto first_frame_size =
72         std::min(max_frame_size - kHeaderSize, rpc_packet.size());
73 
74     std::array<std::byte, kHeaderSize> header{
75         std::byte{kFrameMarker & 0xff},
76         std::byte{(kFrameMarker >> 8) & 0xff},
77         static_cast<std::byte>(rpc_packet.size() & 0xff),
78         static_cast<std::byte>((rpc_packet.size() >> 8) & 0xff),
79     };
80 
81     RpcFrame frame{.header = span(header),
82                    .payload = rpc_packet.first(first_frame_size)};
83     PW_TRY(callback(frame));
84     auto remaining = rpc_packet.subspan(first_frame_size);
85 
86     // Second and subsequent frames (if any).
87     while (!remaining.empty()) {
88       auto fragment_size = std::min(max_frame_size, remaining.size());
89       RpcFrame next_frame{.header = {},
90                           .payload = remaining.first(fragment_size)};
91       PW_TRY(callback(next_frame));
92       remaining = remaining.subspan(fragment_size);
93     }
94 
95     return OkStatus();
96   }
97 };
98 
99 template <size_t kMaxPacketSize>
100 class SimpleRpcPacketDecoder
101     : public RpcPacketDecoder<SimpleRpcPacketDecoder<kMaxPacketSize>> {
102   using Encoder = SimpleRpcPacketEncoder<kMaxPacketSize>;
103 
104  public:
SimpleRpcPacketDecoder()105   SimpleRpcPacketDecoder() { ExpectHeader(); }
106 
107   // Find and decodes `RpcFrame`s in `buffer`. `buffer` may contain zero or
108   // more frames for zero or more packets. Calls `callback` for each
109   // well-formed packet. Malformed packets are ignored and dropped.
Decode(ConstByteSpan buffer,OnRpcPacketDecodedCallback && callback)110   Status Decode(ConstByteSpan buffer, OnRpcPacketDecodedCallback&& callback) {
111     while (!buffer.empty()) {
112       switch (state_) {
113         case State::kReadingHeader: {
114           buffer = buffer.subspan(ReadHeader(buffer));
115           break;
116         }
117         case State::kReadingPayload: {
118           // Payload can only follow a valid header, reset the flag here so
119           // that next invalid header logs again.
120           already_logged_invalid_header_ = false;
121           buffer = buffer.subspan(ReadPayload(buffer, callback));
122           break;
123         }
124       }
125     }
126     return OkStatus();
127   }
128 
129  private:
130   enum class State {
131     kReadingHeader,
132     kReadingPayload,
133   };
134 
135   size_t ReadHeader(ConstByteSpan buffer);
136 
137   size_t ReadPayload(ConstByteSpan buffer,
138                      const OnRpcPacketDecodedCallback& callback);
139 
ExpectHeader()140   void ExpectHeader() {
141     state_ = State::kReadingHeader;
142     bytes_read_ = 0;
143     bytes_remaining_ = Encoder::kHeaderSize;
144   }
145 
ExpectPayload(size_t size)146   void ExpectPayload(size_t size) {
147     state_ = State::kReadingPayload;
148     bytes_read_ = 0;
149     bytes_remaining_ = size;
150   }
151 
152   std::array<std::byte, kMaxPacketSize> packet_{};
153   std::array<std::byte, Encoder::kHeaderSize> header_{};
154 
155   // Current decoder state.
156   State state_;
157   // How many bytes were read in the current state.
158   size_t bytes_read_ = 0;
159   // How many bytes remain to read in the current state.
160   size_t bytes_remaining_ = 0;
161   // When true, discard the received payload instead of buffering it (because
162   // it's too big to buffer).
163   bool discard_payload_ = false;
164   // When true, skip logging on invalid header if we already logged. This is
165   // to prevent logging on every payload byte of a malformed frame.
166   bool already_logged_invalid_header_ = false;
167 };
168 
169 template <size_t kMaxPacketSize>
ReadHeader(ConstByteSpan buffer)170 size_t SimpleRpcPacketDecoder<kMaxPacketSize>::ReadHeader(
171     ConstByteSpan buffer) {
172   const auto read_size = std::min(buffer.size(), bytes_remaining_);
173   bool header_available = false;
174   PW_DASSERT(read_size <= Encoder::kHeaderSize);
175 
176   std::memcpy(header_.data() + bytes_read_, buffer.data(), read_size);
177   bytes_read_ += read_size;
178   bytes_remaining_ -= read_size;
179   header_available = bytes_remaining_ == 0;
180 
181   if (header_available) {
182     uint16_t marker = (static_cast<uint16_t>(header_[1]) << 8) |
183                       static_cast<uint16_t>(header_[0]);
184     uint16_t packet_size = (static_cast<uint16_t>(header_[3]) << 8) |
185                            static_cast<uint16_t>(header_[2]);
186 
187     if (marker != Encoder::kFrameMarker) {
188       // We expected a header but received some data that is definitely not
189       // a header. Skip it and keep waiting for the next header. This could
190       // also be a false positive, e.g. in the worst case we treat some
191       // random data as a header: even then we should eventually be able to
192       // stumble upon a real header and start processing packets again.
193       ExpectHeader();
194       // Consume only a single byte since we're looking for a header in a
195       // broken stream and it could start at the next byte.
196       if (!already_logged_invalid_header_) {
197         internal::LogMalformedRpcFrameHeader();
198         already_logged_invalid_header_ = true;
199       }
200       return 1;
201     }
202     if (packet_size > kMaxPacketSize) {
203       // Consume both header and packet without saving it, as it's too big
204       // for the buffer. This is likely due to max packet size mismatch
205       // between the encoder and the decoder.
206       internal::LogReceivedRpcPacketTooLarge(packet_size, kMaxPacketSize);
207       discard_payload_ = true;
208     }
209     ExpectPayload(packet_size);
210   }
211 
212   return read_size;
213 }
214 
215 template <size_t kMaxPacketSize>
ReadPayload(ConstByteSpan buffer,const OnRpcPacketDecodedCallback & callback)216 size_t SimpleRpcPacketDecoder<kMaxPacketSize>::ReadPayload(
217     ConstByteSpan buffer, const OnRpcPacketDecodedCallback& callback) {
218   if (buffer.size() >= bytes_remaining_ && bytes_read_ == 0) {
219     const auto read_size = bytes_remaining_;
220     // We have the whole packet available in the buffer, no need to copy
221     // it into an internal buffer.
222     callback(buffer.first(read_size));
223     ExpectHeader();
224     return read_size;
225   }
226   // Frame has been split between multiple inputs: assembling it in
227   // an internal buffer.
228   const auto read_size = std::min(buffer.size(), bytes_remaining_);
229 
230   // We could be discarding the payload if it was too big to fit into our
231   // packet buffer.
232   if (!discard_payload_) {
233     PW_DASSERT(bytes_read_ + read_size <= packet_.size());
234     std::memcpy(packet_.data() + bytes_read_, buffer.data(), read_size);
235   }
236 
237   bytes_read_ += read_size;
238   bytes_remaining_ -= read_size;
239   if (bytes_remaining_ == 0) {
240     if (discard_payload_) {
241       discard_payload_ = false;
242     } else {
243       auto packet_span = span(packet_);
244       callback(packet_span.first(bytes_read_));
245     }
246     ExpectHeader();
247   }
248   return read_size;
249 }
250 
251 }  // namespace pw::rpc
252