1 /*
2 * Copyright (C) 2021 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/protozero/filtering/message_filter.h"
18
19 #include "perfetto/base/logging.h"
20 #include "perfetto/protozero/proto_utils.h"
21 #include "src/protozero/filtering/string_filter.h"
22
23 namespace protozero {
24
25 namespace {
26
27 // Inline helpers to append proto fields in output. They are the equivalent of
28 // the protozero::Message::AppendXXX() fields but don't require building and
29 // maintaining a full protozero::Message object or dealing with scattered
30 // output slices.
31 // All these functions assume there is enough space in the output buffer, which
32 // should be always the case assuming that we don't end up generating more
33 // output than input.
34
AppendVarInt(uint32_t field_id,uint64_t value,uint8_t ** out)35 inline void AppendVarInt(uint32_t field_id, uint64_t value, uint8_t** out) {
36 *out = proto_utils::WriteVarInt(proto_utils::MakeTagVarInt(field_id), *out);
37 *out = proto_utils::WriteVarInt(value, *out);
38 }
39
40 // For fixed32 / fixed64.
41 template <typename INT_T /* uint32_t | uint64_t*/>
AppendFixed(uint32_t field_id,INT_T value,uint8_t ** out)42 inline void AppendFixed(uint32_t field_id, INT_T value, uint8_t** out) {
43 *out = proto_utils::WriteVarInt(proto_utils::MakeTagFixed<INT_T>(field_id),
44 *out);
45 memcpy(*out, &value, sizeof(value));
46 *out += sizeof(value);
47 }
48
49 // For length-delimited (string, bytes) fields. Note: this function appends only
50 // the proto preamble and the varint field that states the length of the payload
51 // not the payload itself.
// In the case of submessages, the caller needs to re-write the length at the
// end in the returned memory area.
54 // The problem here is that, because of filtering, the length of a submessage
55 // might be < original length (the original length is still an upper-bound).
56 // Returns a pair with: (1) the pointer where the final length should be written
57 // into, (2) the length of the size field.
58 // The caller must write a redundant varint to match the original size (i.e.
59 // needs to use WriteRedundantVarInt()).
AppendLenDelim(uint32_t field_id,uint32_t len,uint8_t ** out)60 inline std::pair<uint8_t*, uint32_t> AppendLenDelim(uint32_t field_id,
61 uint32_t len,
62 uint8_t** out) {
63 *out = proto_utils::WriteVarInt(proto_utils::MakeTagLengthDelimited(field_id),
64 *out);
65 uint8_t* size_field_start = *out;
66 *out = proto_utils::WriteVarInt(len, *out);
67 const size_t size_field_len = static_cast<size_t>(*out - size_field_start);
68 return std::make_pair(size_field_start, size_field_len);
69 }
70 } // namespace
71
MessageFilter(Config config)72 MessageFilter::MessageFilter(Config config) : config_(std::move(config)) {
73 // Push a state on the stack for the implicit root message.
74 stack_.emplace_back();
75 }
76
MessageFilter()77 MessageFilter::MessageFilter() : MessageFilter(Config()) {}
78
79 MessageFilter::~MessageFilter() = default;
80
LoadFilterBytecode(const void * filter_data,size_t len)81 bool MessageFilter::Config::LoadFilterBytecode(const void* filter_data,
82 size_t len) {
83 return filter_.Load(filter_data, len);
84 }
85
SetFilterRoot(std::initializer_list<uint32_t> field_ids)86 bool MessageFilter::Config::SetFilterRoot(
87 std::initializer_list<uint32_t> field_ids) {
88 uint32_t root_msg_idx = 0;
89 for (uint32_t field_id : field_ids) {
90 auto res = filter_.Query(root_msg_idx, field_id);
91 if (!res.allowed || !res.nested_msg_field())
92 return false;
93 root_msg_idx = res.nested_msg_index;
94 }
95 root_msg_index_ = root_msg_idx;
96 return true;
97 }
98
FilterMessageFragments(const InputSlice * slices,size_t num_slices)99 MessageFilter::FilteredMessage MessageFilter::FilterMessageFragments(
100 const InputSlice* slices,
101 size_t num_slices) {
102 // First compute the upper bound for the output. The filtered message cannot
103 // be > the original message.
104 uint32_t total_len = 0;
105 for (size_t i = 0; i < num_slices; ++i)
106 total_len += slices[i].len;
107 out_buf_.reset(new uint8_t[total_len]);
108 out_ = out_buf_.get();
109 out_end_ = out_ + total_len;
110
111 // Reset the parser state.
112 tokenizer_ = MessageTokenizer();
113 error_ = false;
114 stack_.clear();
115 stack_.resize(2);
116 // stack_[0] is a sentinel and should never be hit in nominal cases. If we
117 // end up there we will just keep consuming the input stream and detecting
118 // at the end, without hurting the fastpath.
119 stack_[0].in_bytes_limit = UINT32_MAX;
120 stack_[0].eat_next_bytes = UINT32_MAX;
121 // stack_[1] is the actual root message.
122 stack_[1].in_bytes_limit = total_len;
123 stack_[1].msg_index = config_.root_msg_index();
124
125 // Process the input data and write the output.
126 for (size_t slice_idx = 0; slice_idx < num_slices; ++slice_idx) {
127 const InputSlice& slice = slices[slice_idx];
128 const uint8_t* data = static_cast<const uint8_t*>(slice.data);
129 for (size_t i = 0; i < slice.len; ++i)
130 FilterOneByte(data[i]);
131 }
132
133 // Construct the output object.
134 PERFETTO_CHECK(out_ >= out_buf_.get() && out_ <= out_end_);
135 auto used_size = static_cast<size_t>(out_ - out_buf_.get());
136 FilteredMessage res{std::move(out_buf_), used_size};
137 res.error = error_;
138 if (stack_.size() != 1 || !tokenizer_.idle() ||
139 stack_[0].in_bytes != total_len) {
140 res.error = true;
141 }
142 return res;
143 }
144
// Advances the filtering state machine by one input byte.
// The byte is handled in one of two ways, depending on the state on top of
// |stack_|:
// - If that state is in the middle of a length-delimited payload
//   (eat_next_bytes > 0), the byte is copied to the output or dropped,
//   according to the state's |action|.
// - Otherwise the byte is fed to |tokenizer_|. When that completes a field,
//   the filter is queried and the field is either re-emitted, skipped, or (for
//   allowed nested messages) a new state is prepared to be pushed on |stack_|.
// After that, the byte is accounted to the current state and every state whose
// input has been fully consumed is popped, backfilling its length field.
// On malformed input this calls SetUnrecoverableErrorState() and returns.
void MessageFilter::FilterOneByte(uint8_t octet) {
  PERFETTO_DCHECK(!stack_.empty());

  auto* state = &stack_.back();
  // Filled only if this byte completes the preamble of an allowed nested
  // message. It is pushed at the very end of this function.
  StackState next_state{};
  bool push_next_state = false;

  if (state->eat_next_bytes > 0) {
    // This is the case where the previous tokenizer_.Push() call returned a
    // length delimited message which is NOT a submessage (a string or a bytes
    // field). We just want to consume it, and pass it through/filter strings
    // if the field was allowed.
    --state->eat_next_bytes;
    if (state->action == StackState::kPassthrough) {
      *(out_++) = octet;
    } else if (state->action == StackState::kFilterString) {
      *(out_++) = octet;
      if (state->eat_next_bytes == 0) {
        // This was the last byte of the string: the whole payload now sits in
        // the output buffer, starting at |filter_string_ptr|, and is handed
        // to the string filter in place.
        config_.string_filter().MaybeFilter(
            reinterpret_cast<char*>(state->filter_string_ptr),
            static_cast<size_t>(out_ - state->filter_string_ptr));
      }
    }
    // Any other action: the byte is consumed without being written.
  } else {
    MessageTokenizer::Token token = tokenizer_.Push(octet);
    // |token| will not be valid() in most cases and this is WAI. When pushing
    // a varint field, only the last byte yields a token, all the other bytes
    // return an invalid token, they just update the internal tokenizer state.
    if (token.valid()) {
      auto filter = config_.filter().Query(state->msg_index, token.field_id);
      switch (token.type) {
        case proto_utils::ProtoWireType::kVarInt:
          if (filter.allowed && filter.simple_field())
            AppendVarInt(token.field_id, token.value, &out_);
          break;
        case proto_utils::ProtoWireType::kFixed32:
          if (filter.allowed && filter.simple_field())
            AppendFixed(token.field_id, static_cast<uint32_t>(token.value),
                        &out_);
          break;
        case proto_utils::ProtoWireType::kFixed64:
          if (filter.allowed && filter.simple_field())
            AppendFixed(token.field_id, static_cast<uint64_t>(token.value),
                        &out_);
          break;
        case proto_utils::ProtoWireType::kLengthDelimited:
          // Here we have two cases:
          // A. A simple string/bytes field: we just want to consume the next
          //    bytes (the string payload), optionally passing them through in
          //    output if the field is allowed.
          // B. This is a nested submessage. In this case we want to recurse and
          //    push a new state on the stack.
          // Note that we can't tell the difference between a
          // "non-allowed string" and a "non-allowed submessage". But it doesn't
          // matter because in both cases we just want to skip the next N bytes.
          const auto submessage_len = static_cast<uint32_t>(token.value);
          // The -1 accounts for the current byte, which is added to
          // |in_bytes| only further below, after the token has been handled.
          auto in_bytes_left = state->in_bytes_limit - state->in_bytes - 1;
          if (PERFETTO_UNLIKELY(submessage_len > in_bytes_left)) {
            // This is a malicious / malformed string/bytes/submessage that
            // claims to be larger than the outer message that contains it.
            return SetUnrecoverableErrorState();
          }

          if (filter.allowed && filter.nested_msg_field() &&
              submessage_len > 0) {
            // submessage_len == 0 is the edge case of a message with a 0-len
            // (but present) submessage. In this case, if allowed, we don't want
            // to push any further state (doing so would desync the FSM) but we
            // still want to emit it.
            // At this point |submessage_len| is only an upper bound. The
            // final message written in output can be <= the one in input,
            // only some of its fields might be allowed (also remember that
            // this class implicitly removes redundancy varint encoding of
            // len-delimited field lengths). The final length varint (the
            // return value of AppendLenDelim()) will be filled when popping
            // from |stack_|.
            auto size_field =
                AppendLenDelim(token.field_id, submessage_len, &out_);
            push_next_state = true;
            next_state.field_id = token.field_id;
            next_state.msg_index = filter.nested_msg_index;
            next_state.in_bytes_limit = submessage_len;
            next_state.size_field = size_field.first;
            next_state.size_field_len = size_field.second;
            next_state.out_bytes_written_at_start = out_written();
          } else {
            // A string or bytes field, or a 0 length submessage.
            state->eat_next_bytes = submessage_len;
            if (filter.allowed && filter.filter_string_field()) {
              state->action = StackState::kFilterString;
              AppendLenDelim(token.field_id, submessage_len, &out_);
              // Remember where the payload starts in the output, so that the
              // string filter can be run over it once fully copied.
              state->filter_string_ptr = out_;
            } else if (filter.allowed) {
              state->action = StackState::kPassthrough;
              AppendLenDelim(token.field_id, submessage_len, &out_);
            } else {
              state->action = StackState::kDrop;
            }
          }
          break;
      }  // switch(type)

      if (PERFETTO_UNLIKELY(track_field_usage_)) {
        IncrementCurrentFieldUsage(token.field_id, filter.allowed);
      }
    }  // if (token.valid)
  }    // if (eat_next_bytes == 0)

  // Account the byte just processed to the current message, then pop every
  // message that has been fully consumed. This is a loop because the end of a
  // submessage can coincide with the end of one or more of its ancestors.
  ++state->in_bytes;
  while (state->in_bytes >= state->in_bytes_limit) {
    PERFETTO_DCHECK(state->in_bytes == state->in_bytes_limit);
    // The current message is over: a nested message whose preamble ended on
    // this very byte must not be entered.
    push_next_state = false;

    // We can't possibly write more than we read.
    const uint32_t msg_bytes_written = static_cast<uint32_t>(
        out_written() - state->out_bytes_written_at_start);
    PERFETTO_DCHECK(msg_bytes_written <= state->in_bytes_limit);

    // Backfill the length field of the message being popped, now that the
    // number of bytes written for it is known. The length is written as a
    // redundant varint of |size_field_len| bytes, i.e. as wide as the length
    // that AppendLenDelim() wrote when the message was opened, so the payload
    // that follows does not need to be moved.
    // NOTE(review): nothing in this file sets |size_field| for the root state;
    // presumably its |size_field_len| is 0 and this is a no-op. Confirm
    // against the StackState defaults.
    proto_utils::WriteRedundantVarInt(msg_bytes_written, state->size_field,
                                      state->size_field_len);

    // The bytes consumed by the popped message are also part of its parent.
    const uint32_t in_bytes_processes_for_last_msg = state->in_bytes;
    stack_.pop_back();
    PERFETTO_CHECK(!stack_.empty());
    state = &stack_.back();
    state->in_bytes += in_bytes_processes_for_last_msg;
    if (PERFETTO_UNLIKELY(!tokenizer_.idle())) {
      // If we hit this case, it means that we got to the end of a submessage
      // while decoding a field. We can't recover from this and we don't want to
      // propagate a broken sub-message.
      return SetUnrecoverableErrorState();
    }
  }

  // Enter the nested message announced by this byte, if any. The following
  // input bytes will then be processed against the new state.
  if (push_next_state) {
    PERFETTO_DCHECK(tokenizer_.idle());
    stack_.emplace_back(std::move(next_state));
    state = &stack_.back();
  }
}
286
SetUnrecoverableErrorState()287 void MessageFilter::SetUnrecoverableErrorState() {
288 error_ = true;
289 stack_.clear();
290 stack_.resize(1);
291 auto& state = stack_[0];
292 state.eat_next_bytes = UINT32_MAX;
293 state.in_bytes_limit = UINT32_MAX;
294 state.action = StackState::kDrop;
295 out_ = out_buf_.get(); // Reset the write pointer.
296 }
297
IncrementCurrentFieldUsage(uint32_t field_id,bool allowed)298 void MessageFilter::IncrementCurrentFieldUsage(uint32_t field_id,
299 bool allowed) {
300 // Slowpath. Used mainly in offline tools and tests to workout used fields in
301 // a proto.
302 PERFETTO_DCHECK(track_field_usage_);
303
304 // Field path contains a concatenation of varints, one for each nesting level.
305 // e.g. y in message Root { Sub x = 2; }; message Sub { SubSub y = 7; }
306 // is encoded as [varint(2) + varint(7)].
307 // We use varint to take the most out of SSO (small string opt). In most cases
308 // the path will fit in the on-stack 22 bytes, requiring no heap.
309 std::string field_path;
310
311 auto append_field_id = [&field_path](uint32_t id) {
312 uint8_t buf[10];
313 uint8_t* end = proto_utils::WriteVarInt(id, buf);
314 field_path.append(reinterpret_cast<char*>(buf),
315 static_cast<size_t>(end - buf));
316 };
317
318 // Append all the ancestors IDs from the state stack.
319 // The first entry of the stack has always ID 0 and we skip it (we don't know
320 // the ID of the root message itself).
321 PERFETTO_DCHECK(stack_.size() >= 2 && stack_[1].field_id == 0);
322 for (size_t i = 2; i < stack_.size(); ++i)
323 append_field_id(stack_[i].field_id);
324 // Append the id of the field in the current message.
325 append_field_id(field_id);
326 field_usage_[field_path] += allowed ? 1 : -1;
327 }
328
329 } // namespace protozero
330