1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/tracing/service/packet_stream_validator.h"
18
19 #include <stddef.h>
20
21 #include <cinttypes>
22
23 #include "perfetto/base/logging.h"
24 #include "perfetto/ext/base/utils.h"
25 #include "perfetto/protozero/proto_utils.h"
26
27 #include "protos/perfetto/trace/trace_packet.pbzero.h"
28
29 namespace perfetto {
30
31 namespace {
32
33 using protozero::proto_utils::ProtoWireType;
34
35 const uint32_t kReservedFieldIds[] = {
36 protos::pbzero::TracePacket::kTrustedUidFieldNumber,
37 protos::pbzero::TracePacket::kTrustedPacketSequenceIdFieldNumber,
38 protos::pbzero::TracePacket::kTraceConfigFieldNumber,
39 protos::pbzero::TracePacket::kTraceStatsFieldNumber,
40 protos::pbzero::TracePacket::kCompressedPacketsFieldNumber,
41 protos::pbzero::TracePacket::kSynchronizationMarkerFieldNumber,
42 protos::pbzero::TracePacket::kTrustedPidFieldNumber,
43 protos::pbzero::TracePacket::kMachineIdFieldNumber,
44 };
45
46 // This translation unit is quite subtle and perf-sensitive. Remember to check
47 // BM_PacketStreamValidator in perfetto_benchmarks when making changes.
48
49 // Checks that a packet, spread over several slices, is well-formed and doesn't
50 // contain reserved top-level fields.
51 // The checking logic is based on a state-machine that skips the fields' payload
52 // and operates as follows:
53 // +-------------------------------+ <-------------------------+
54 // +----------> | Read field preamble (varint) | <----------------------+ |
55 // | +-------------------------------+ | |
56 // | | | | | |
57 // | <Varint> <Fixed 32/64> <Length-delimited field> | |
58 // | V | V | |
59 // | +------------------+ | +--------------+ | |
60 // | | Read field value | | | Read length | | |
61 // | | (another varint) | | | (varint) | | |
62 // | +------------------+ | +--------------+ | |
63 // | | V V | |
64 // +-----------+ +----------------+ +-----------------+ | |
65 // | Skip 4/8 Bytes | | Skip $len Bytes |-------+ |
66 // +----------------+ +-----------------+ |
67 // | |
68 // +------------------------------------------+
69 class ProtoFieldParserFSM {
70 public:
71 // This method effectively continuously parses varints (either for the field
72 // preamble or the payload or the submessage length) and tells the caller
73 // (the Validate() method) how many bytes to skip until the next field.
Push(uint8_t octet)74 size_t Push(uint8_t octet) {
75 varint_ |= static_cast<uint64_t>(octet & 0x7F) << varint_shift_;
76 if (octet & 0x80) {
77 varint_shift_ += 7;
78 if (varint_shift_ >= 64) {
79 // Do not invoke UB on next call.
80 varint_shift_ = 0;
81 state_ = kInvalidVarInt;
82 }
83 return 0;
84 }
85 uint64_t varint = varint_;
86 varint_ = 0;
87 varint_shift_ = 0;
88
89 switch (state_) {
90 case kFieldPreamble: {
91 uint64_t field_type = varint & 7; // 7 = 0..0111
92 auto field_id = static_cast<uint32_t>(varint >> 3);
93 // Check if the field id is reserved, go into an error state if it is.
94 for (size_t i = 0; i < base::ArraySize(kReservedFieldIds); ++i) {
95 if (field_id == kReservedFieldIds[i]) {
96 state_ = kWroteReservedField;
97 return 0;
98 }
99 }
100 // The field type is legit, now check it's well formed and within
101 // boundaries.
102 if (field_type == static_cast<uint64_t>(ProtoWireType::kVarInt)) {
103 state_ = kVarIntValue;
104 } else if (field_type ==
105 static_cast<uint64_t>(ProtoWireType::kFixed32)) {
106 return 4;
107 } else if (field_type ==
108 static_cast<uint64_t>(ProtoWireType::kFixed64)) {
109 return 8;
110 } else if (field_type ==
111 static_cast<uint64_t>(ProtoWireType::kLengthDelimited)) {
112 state_ = kLenDelimitedLen;
113 } else {
114 state_ = kUnknownFieldType;
115 }
116 return 0;
117 }
118
119 case kVarIntValue: {
120 // Consume the int field payload and go back to the next field.
121 state_ = kFieldPreamble;
122 return 0;
123 }
124
125 case kLenDelimitedLen: {
126 if (varint > protozero::proto_utils::kMaxMessageLength) {
127 state_ = kMessageTooBig;
128 return 0;
129 }
130 state_ = kFieldPreamble;
131 return static_cast<size_t>(varint);
132 }
133
134 case kWroteReservedField:
135 case kUnknownFieldType:
136 case kMessageTooBig:
137 case kInvalidVarInt:
138 // Persistent error states.
139 return 0;
140
141 } // switch(state_)
142 return 0; // To keep GCC happy.
143 }
144
145 // Queried at the end of the all payload. A message is well-formed only
146 // if the FSM is back to the state where it should parse the next field and
147 // hasn't started parsing any preamble.
valid() const148 bool valid() const { return state_ == kFieldPreamble && varint_shift_ == 0; }
state() const149 int state() const { return static_cast<int>(state_); }
150
151 private:
152 enum State {
153 kFieldPreamble = 0, // Parsing the varint for the field preamble.
154 kVarIntValue, // Parsing the varint value for the field payload.
155 kLenDelimitedLen, // Parsing the length of the length-delimited field.
156
157 // Error states:
158 kWroteReservedField, // Tried to set a reserved field id.
159 kUnknownFieldType, // Encountered an invalid field type.
160 kMessageTooBig, // Size of the length delimited message was too big.
161 kInvalidVarInt, // VarInt larger than 64 bits.
162 };
163
164 State state_ = kFieldPreamble;
165 uint64_t varint_ = 0;
166 uint32_t varint_shift_ = 0;
167 };
168
169 } // namespace
170
171 // static
Validate(const Slices & slices)172 bool PacketStreamValidator::Validate(const Slices& slices) {
173 ProtoFieldParserFSM parser;
174 size_t skip_bytes = 0;
175 for (const Slice& slice : slices) {
176 for (size_t i = 0; i < slice.size;) {
177 const size_t skip_bytes_cur_slice = std::min(skip_bytes, slice.size - i);
178 if (skip_bytes_cur_slice > 0) {
179 i += skip_bytes_cur_slice;
180 skip_bytes -= skip_bytes_cur_slice;
181 } else {
182 uint8_t octet = *(reinterpret_cast<const uint8_t*>(slice.start) + i);
183 skip_bytes = parser.Push(octet);
184 i++;
185 }
186 }
187 }
188 if (skip_bytes == 0 && parser.valid())
189 return true;
190
191 PERFETTO_DLOG("Packet validation error (state %d, skip = %zu)",
192 parser.state(), skip_bytes);
193 return false;
194 }
195
196 } // namespace perfetto
197