1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 // https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14
15 #include "pw_protobuf/decoder.h"
16
17 #include <cstring>
18
19 #include "pw_assert/check.h"
20 #include "pw_varint/varint.h"
21
22 namespace pw::protobuf {
23
Next()24 Status Decoder::Next() {
25 if (!previous_field_consumed_) {
26 if (Status status = SkipField(); !status.ok()) {
27 return status;
28 }
29 }
30 if (proto_.empty()) {
31 return Status::OutOfRange();
32 }
33 previous_field_consumed_ = false;
34 return GetFieldSize().ok() ? OkStatus() : Status::DataLoss();
35 }
36
SkipField()37 Status Decoder::SkipField() {
38 if (proto_.empty()) {
39 return Status::OutOfRange();
40 }
41
42 size_t bytes_to_skip = GetFieldSize().total();
43 if (bytes_to_skip == 0) {
44 return Status::DataLoss();
45 }
46
47 proto_ = proto_.subspan(bytes_to_skip);
48 return proto_.empty() ? Status::OutOfRange() : OkStatus();
49 }
50
FieldNumber() const51 uint32_t Decoder::FieldNumber() const {
52 uint64_t key;
53 varint::Decode(proto_, &key);
54 if (!FieldKey::IsValidKey(key)) {
55 return 0;
56 }
57 PW_DCHECK(key <= std::numeric_limits<uint32_t>::max());
58 return FieldKey(static_cast<uint32_t>(key)).field_number();
59 }
60
ReadUint32(uint32_t * out)61 Status Decoder::ReadUint32(uint32_t* out) {
62 uint64_t value = 0;
63 Status status = ReadUint64(&value);
64 if (!status.ok()) {
65 return status;
66 }
67 if (value > std::numeric_limits<uint32_t>::max()) {
68 return Status::OutOfRange();
69 }
70 *out = static_cast<uint32_t>(value);
71 return OkStatus();
72 }
73
ReadSint32(int32_t * out)74 Status Decoder::ReadSint32(int32_t* out) {
75 int64_t value = 0;
76 Status status = ReadSint64(&value);
77 if (!status.ok()) {
78 return status;
79 }
80 if (value > std::numeric_limits<int32_t>::max()) {
81 return Status::OutOfRange();
82 }
83 *out = static_cast<uint32_t>(value);
84 return OkStatus();
85 }
86
ReadSint64(int64_t * out)87 Status Decoder::ReadSint64(int64_t* out) {
88 uint64_t value = 0;
89 Status status = ReadUint64(&value);
90 if (!status.ok()) {
91 return status;
92 }
93 *out = varint::ZigZagDecode(value);
94 return OkStatus();
95 }
96
ReadBool(bool * out)97 Status Decoder::ReadBool(bool* out) {
98 uint64_t value = 0;
99 Status status = ReadUint64(&value);
100 if (!status.ok()) {
101 return status;
102 }
103 *out = value;
104 return OkStatus();
105 }
106
ReadString(std::string_view * out)107 Status Decoder::ReadString(std::string_view* out) {
108 span<const std::byte> bytes;
109 Status status = ReadDelimited(&bytes);
110 if (!status.ok()) {
111 return status;
112 }
113 *out = std::string_view(reinterpret_cast<const char*>(bytes.data()),
114 bytes.size());
115 return OkStatus();
116 }
117
GetFieldSize() const118 Decoder::FieldSize Decoder::GetFieldSize() const {
119 uint64_t key;
120 size_t key_size = varint::Decode(proto_, &key);
121 if (key_size == 0 || !FieldKey::IsValidKey(key)) {
122 return FieldSize::Invalid();
123 }
124
125 span<const std::byte> remainder = proto_.subspan(key_size);
126 uint64_t value = 0;
127 size_t expected_size = 0;
128
129 PW_DCHECK(key <= std::numeric_limits<uint32_t>::max());
130 switch (FieldKey(static_cast<uint32_t>(key)).wire_type()) {
131 case WireType::kVarint:
132 expected_size = varint::Decode(remainder, &value);
133 if (expected_size == 0) {
134 return FieldSize::Invalid();
135 }
136 break;
137
138 case WireType::kDelimited: {
139 // Varint at cursor indicates size of the field.
140 const size_t delimited_size = varint::Decode(remainder, &value);
141 if (delimited_size == 0) {
142 return FieldSize::Invalid();
143 }
144 key_size += delimited_size;
145 expected_size += value;
146 break;
147 }
148 case WireType::kFixed32:
149 expected_size = sizeof(uint32_t);
150 break;
151
152 case WireType::kFixed64:
153 expected_size = sizeof(uint64_t);
154 break;
155 }
156
157 if (remainder.size() < expected_size) {
158 return FieldSize::Invalid();
159 }
160
161 return FieldSize{key_size, expected_size};
162 }
163
ConsumeKey(WireType expected_type)164 Status Decoder::ConsumeKey(WireType expected_type) {
165 uint64_t key;
166 size_t bytes_read = varint::Decode(proto_, &key);
167 if (bytes_read == 0) {
168 return Status::FailedPrecondition();
169 }
170
171 if (!FieldKey::IsValidKey(key)) {
172 return Status::DataLoss();
173 }
174
175 PW_DCHECK(key <= std::numeric_limits<uint32_t>::max());
176 if (FieldKey(static_cast<uint32_t>(key)).wire_type() != expected_type) {
177 return Status::FailedPrecondition();
178 }
179
180 // Advance past the key.
181 proto_ = proto_.subspan(bytes_read);
182 return OkStatus();
183 }
184
ReadVarint(uint64_t * out)185 Status Decoder::ReadVarint(uint64_t* out) {
186 if (Status status = ConsumeKey(WireType::kVarint); !status.ok()) {
187 return status;
188 }
189
190 size_t bytes_read = varint::Decode(proto_, out);
191 if (bytes_read == 0) {
192 return Status::DataLoss();
193 }
194
195 // Advance to the next field.
196 proto_ = proto_.subspan(bytes_read);
197 previous_field_consumed_ = true;
198 return OkStatus();
199 }
200
ReadFixed(std::byte * out,size_t size)201 Status Decoder::ReadFixed(std::byte* out, size_t size) {
202 WireType expected_wire_type =
203 size == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
204 Status status = ConsumeKey(expected_wire_type);
205 if (!status.ok()) {
206 return status;
207 }
208
209 if (proto_.size() < size) {
210 return Status::DataLoss();
211 }
212
213 std::memcpy(out, proto_.data(), size);
214 proto_ = proto_.subspan(size);
215 previous_field_consumed_ = true;
216
217 return OkStatus();
218 }
219
ReadDelimited(span<const std::byte> * out)220 Status Decoder::ReadDelimited(span<const std::byte>* out) {
221 Status status = ConsumeKey(WireType::kDelimited);
222 if (!status.ok()) {
223 return status;
224 }
225
226 uint64_t length;
227 size_t bytes_read = varint::Decode(proto_, &length);
228 if (bytes_read == 0) {
229 return Status::DataLoss();
230 }
231
232 proto_ = proto_.subspan(bytes_read);
233 if (proto_.size() < length) {
234 return Status::DataLoss();
235 }
236
237 *out = proto_.first(length);
238 proto_ = proto_.subspan(length);
239 previous_field_consumed_ = true;
240
241 return OkStatus();
242 }
243
Decode(span<const std::byte> proto)244 Status CallbackDecoder::Decode(span<const std::byte> proto) {
245 if (handler_ == nullptr || state_ != kReady) {
246 return Status::FailedPrecondition();
247 }
248
249 state_ = kDecodeInProgress;
250 decoder_.Reset(proto);
251
252 // Iterate the proto, calling the handler with each field number.
253 while (state_ == kDecodeInProgress) {
254 if (Status status = decoder_.Next(); !status.ok()) {
255 if (status.IsOutOfRange()) {
256 // Reached the end of the proto.
257 break;
258 }
259
260 // Proto data is malformed.
261 return status;
262 }
263
264 Status status = handler_->ProcessField(*this, decoder_.FieldNumber());
265 if (!status.ok()) {
266 state_ = status.IsCancelled() ? kDecodeCancelled : kDecodeFailed;
267 return status;
268 }
269
270 // The callback function can modify the decoder's state; check that
271 // everything is still okay.
272 if (state_ == kDecodeFailed) {
273 break;
274 }
275 }
276
277 if (state_ != kDecodeInProgress) {
278 return Status::DataLoss();
279 }
280
281 state_ = kReady;
282 return OkStatus();
283 }
284
285 } // namespace pw::protobuf
286