xref: /aosp_15_r20/external/pigweed/pw_protobuf/decoder.cc (revision 61c4878ac05f98d0ceed94b57d316916de578985)
1 // Copyright 2020 The Pigweed Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 // use this file except in compliance with the License. You may obtain a copy of
5 // the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 // WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 // License for the specific language governing permissions and limitations under
13 // the License.
14 
15 #include "pw_protobuf/decoder.h"
16 
17 #include <cstring>
18 
19 #include "pw_assert/check.h"
20 #include "pw_varint/varint.h"
21 
22 namespace pw::protobuf {
23 
Next()24 Status Decoder::Next() {
25   if (!previous_field_consumed_) {
26     if (Status status = SkipField(); !status.ok()) {
27       return status;
28     }
29   }
30   if (proto_.empty()) {
31     return Status::OutOfRange();
32   }
33   previous_field_consumed_ = false;
34   return GetFieldSize().ok() ? OkStatus() : Status::DataLoss();
35 }
36 
SkipField()37 Status Decoder::SkipField() {
38   if (proto_.empty()) {
39     return Status::OutOfRange();
40   }
41 
42   size_t bytes_to_skip = GetFieldSize().total();
43   if (bytes_to_skip == 0) {
44     return Status::DataLoss();
45   }
46 
47   proto_ = proto_.subspan(bytes_to_skip);
48   return proto_.empty() ? Status::OutOfRange() : OkStatus();
49 }
50 
FieldNumber() const51 uint32_t Decoder::FieldNumber() const {
52   uint64_t key;
53   varint::Decode(proto_, &key);
54   if (!FieldKey::IsValidKey(key)) {
55     return 0;
56   }
57   PW_DCHECK(key <= std::numeric_limits<uint32_t>::max());
58   return FieldKey(static_cast<uint32_t>(key)).field_number();
59 }
60 
ReadUint32(uint32_t * out)61 Status Decoder::ReadUint32(uint32_t* out) {
62   uint64_t value = 0;
63   Status status = ReadUint64(&value);
64   if (!status.ok()) {
65     return status;
66   }
67   if (value > std::numeric_limits<uint32_t>::max()) {
68     return Status::OutOfRange();
69   }
70   *out = static_cast<uint32_t>(value);
71   return OkStatus();
72 }
73 
ReadSint32(int32_t * out)74 Status Decoder::ReadSint32(int32_t* out) {
75   int64_t value = 0;
76   Status status = ReadSint64(&value);
77   if (!status.ok()) {
78     return status;
79   }
80   if (value > std::numeric_limits<int32_t>::max()) {
81     return Status::OutOfRange();
82   }
83   *out = static_cast<uint32_t>(value);
84   return OkStatus();
85 }
86 
ReadSint64(int64_t * out)87 Status Decoder::ReadSint64(int64_t* out) {
88   uint64_t value = 0;
89   Status status = ReadUint64(&value);
90   if (!status.ok()) {
91     return status;
92   }
93   *out = varint::ZigZagDecode(value);
94   return OkStatus();
95 }
96 
ReadBool(bool * out)97 Status Decoder::ReadBool(bool* out) {
98   uint64_t value = 0;
99   Status status = ReadUint64(&value);
100   if (!status.ok()) {
101     return status;
102   }
103   *out = value;
104   return OkStatus();
105 }
106 
ReadString(std::string_view * out)107 Status Decoder::ReadString(std::string_view* out) {
108   span<const std::byte> bytes;
109   Status status = ReadDelimited(&bytes);
110   if (!status.ok()) {
111     return status;
112   }
113   *out = std::string_view(reinterpret_cast<const char*>(bytes.data()),
114                           bytes.size());
115   return OkStatus();
116 }
117 
GetFieldSize() const118 Decoder::FieldSize Decoder::GetFieldSize() const {
119   uint64_t key;
120   size_t key_size = varint::Decode(proto_, &key);
121   if (key_size == 0 || !FieldKey::IsValidKey(key)) {
122     return FieldSize::Invalid();
123   }
124 
125   span<const std::byte> remainder = proto_.subspan(key_size);
126   uint64_t value = 0;
127   size_t expected_size = 0;
128 
129   PW_DCHECK(key <= std::numeric_limits<uint32_t>::max());
130   switch (FieldKey(static_cast<uint32_t>(key)).wire_type()) {
131     case WireType::kVarint:
132       expected_size = varint::Decode(remainder, &value);
133       if (expected_size == 0) {
134         return FieldSize::Invalid();
135       }
136       break;
137 
138     case WireType::kDelimited: {
139       // Varint at cursor indicates size of the field.
140       const size_t delimited_size = varint::Decode(remainder, &value);
141       if (delimited_size == 0) {
142         return FieldSize::Invalid();
143       }
144       key_size += delimited_size;
145       expected_size += value;
146       break;
147     }
148     case WireType::kFixed32:
149       expected_size = sizeof(uint32_t);
150       break;
151 
152     case WireType::kFixed64:
153       expected_size = sizeof(uint64_t);
154       break;
155   }
156 
157   if (remainder.size() < expected_size) {
158     return FieldSize::Invalid();
159   }
160 
161   return FieldSize{key_size, expected_size};
162 }
163 
ConsumeKey(WireType expected_type)164 Status Decoder::ConsumeKey(WireType expected_type) {
165   uint64_t key;
166   size_t bytes_read = varint::Decode(proto_, &key);
167   if (bytes_read == 0) {
168     return Status::FailedPrecondition();
169   }
170 
171   if (!FieldKey::IsValidKey(key)) {
172     return Status::DataLoss();
173   }
174 
175   PW_DCHECK(key <= std::numeric_limits<uint32_t>::max());
176   if (FieldKey(static_cast<uint32_t>(key)).wire_type() != expected_type) {
177     return Status::FailedPrecondition();
178   }
179 
180   // Advance past the key.
181   proto_ = proto_.subspan(bytes_read);
182   return OkStatus();
183 }
184 
ReadVarint(uint64_t * out)185 Status Decoder::ReadVarint(uint64_t* out) {
186   if (Status status = ConsumeKey(WireType::kVarint); !status.ok()) {
187     return status;
188   }
189 
190   size_t bytes_read = varint::Decode(proto_, out);
191   if (bytes_read == 0) {
192     return Status::DataLoss();
193   }
194 
195   // Advance to the next field.
196   proto_ = proto_.subspan(bytes_read);
197   previous_field_consumed_ = true;
198   return OkStatus();
199 }
200 
ReadFixed(std::byte * out,size_t size)201 Status Decoder::ReadFixed(std::byte* out, size_t size) {
202   WireType expected_wire_type =
203       size == sizeof(uint32_t) ? WireType::kFixed32 : WireType::kFixed64;
204   Status status = ConsumeKey(expected_wire_type);
205   if (!status.ok()) {
206     return status;
207   }
208 
209   if (proto_.size() < size) {
210     return Status::DataLoss();
211   }
212 
213   std::memcpy(out, proto_.data(), size);
214   proto_ = proto_.subspan(size);
215   previous_field_consumed_ = true;
216 
217   return OkStatus();
218 }
219 
ReadDelimited(span<const std::byte> * out)220 Status Decoder::ReadDelimited(span<const std::byte>* out) {
221   Status status = ConsumeKey(WireType::kDelimited);
222   if (!status.ok()) {
223     return status;
224   }
225 
226   uint64_t length;
227   size_t bytes_read = varint::Decode(proto_, &length);
228   if (bytes_read == 0) {
229     return Status::DataLoss();
230   }
231 
232   proto_ = proto_.subspan(bytes_read);
233   if (proto_.size() < length) {
234     return Status::DataLoss();
235   }
236 
237   *out = proto_.first(length);
238   proto_ = proto_.subspan(length);
239   previous_field_consumed_ = true;
240 
241   return OkStatus();
242 }
243 
Decode(span<const std::byte> proto)244 Status CallbackDecoder::Decode(span<const std::byte> proto) {
245   if (handler_ == nullptr || state_ != kReady) {
246     return Status::FailedPrecondition();
247   }
248 
249   state_ = kDecodeInProgress;
250   decoder_.Reset(proto);
251 
252   // Iterate the proto, calling the handler with each field number.
253   while (state_ == kDecodeInProgress) {
254     if (Status status = decoder_.Next(); !status.ok()) {
255       if (status.IsOutOfRange()) {
256         // Reached the end of the proto.
257         break;
258       }
259 
260       // Proto data is malformed.
261       return status;
262     }
263 
264     Status status = handler_->ProcessField(*this, decoder_.FieldNumber());
265     if (!status.ok()) {
266       state_ = status.IsCancelled() ? kDecodeCancelled : kDecodeFailed;
267       return status;
268     }
269 
270     // The callback function can modify the decoder's state; check that
271     // everything is still okay.
272     if (state_ == kDecodeFailed) {
273       break;
274     }
275   }
276 
277   if (state_ != kDecodeInProgress) {
278     return Status::DataLoss();
279   }
280 
281   state_ = kReady;
282   return OkStatus();
283 }
284 
285 }  // namespace pw::protobuf
286