xref: /aosp_15_r20/external/zucchini/heuristic_ensemble_matcher.cc (revision a03ca8b91e029cd15055c20c78c2e087c84792e4)
1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "components/zucchini/heuristic_ensemble_matcher.h"

#include <algorithm>
#include <cmath>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "base/bind.h"
#include "base/logging.h"
#include "base/numerics/safe_conversions.h"
#include "base/strings/stringprintf.h"
#include "components/zucchini/binary_data_histogram.h"
#include "components/zucchini/element_detection.h"
#include "components/zucchini/image_utils.h"
#include "components/zucchini/io_utils.h"
21 
22 namespace zucchini {
23 
24 namespace {
25 
26 /******** Helper Functions ********/
27 
28 // Uses |detector| to find embedded executables inside |image|, and returns the
29 // result on success, or std::nullopt on failure,  which occurs if too many (>
30 // |kElementLimit|) elements are found.
FindEmbeddedElements(ConstBufferView image,const std::string & name,ElementDetector && detector)31 std::optional<std::vector<Element>> FindEmbeddedElements(
32     ConstBufferView image,
33     const std::string& name,
34     ElementDetector&& detector) {
35   // Maximum number of Elements in a file. This is enforced because our matching
36   // algorithm is O(n^2), which suffices for regular archive files that should
37   // have up to 10's of executable files. An archive containing 100's of
38   // executables is likely pathological, and is rejected to prevent exploits.
39   static constexpr size_t kElementLimit = 256;
40   std::vector<Element> elements;
41   ElementFinder element_finder(image, std::move(detector));
42   for (auto element = element_finder.GetNext();
43        element.has_value() && elements.size() <= kElementLimit;
44        element = element_finder.GetNext()) {
45     elements.push_back(*element);
46   }
47   if (elements.size() >= kElementLimit) {
48     LOG(WARNING) << name << ": Found too many elements.";
49     return std::nullopt;
50   }
51   LOG(INFO) << name << ": Found " << elements.size() << " elements.";
52   return elements;
53 }
54 
55 // Determines whether a proposed comparison between Elements should be rejected
56 // early, to decrease the likelihood of creating false-positive matches, which
57 // may be costly for patching. Our heuristic simply prohibits big difference in
58 // size (relative and absolute) between matched elements.
UnsafeDifference(const Element & old_element,const Element & new_element)59 bool UnsafeDifference(const Element& old_element, const Element& new_element) {
60   static constexpr double kMaxBloat = 2.0;
61   static constexpr size_t kMinWorrysomeDifference = 2 << 20;  // 2MB
62   size_t lo_size = std::min(old_element.size, new_element.size);
63   size_t hi_size = std::max(old_element.size, new_element.size);
64   if (hi_size - lo_size < kMinWorrysomeDifference)
65     return false;
66   if (hi_size < lo_size * kMaxBloat)
67     return false;
68   return true;
69 }
70 
operator <<(std::ostream & stream,const Element & elt)71 std::ostream& operator<<(std::ostream& stream, const Element& elt) {
72   stream << "(" << CastExecutableTypeToString(elt.exe_type) << ", "
73          << AsHex<8, size_t>(elt.offset) << " +" << AsHex<8, size_t>(elt.size)
74          << ")";
75   return stream;
76 }
77 
78 /******** MatchingInfoOut ********/
79 
80 // A class to output detailed information during ensemble matching. Extracting
81 // the functionality to a separate class decouples formatting and printing logic
82 // from matching logic. The base class consists of stubs.
83 class MatchingInfoOut {
84  protected:
85   MatchingInfoOut() = default;
86   MatchingInfoOut(const MatchingInfoOut&) = delete;
87   const MatchingInfoOut& operator=(const MatchingInfoOut&) = delete;
88 
89  public:
90   virtual ~MatchingInfoOut() = default;
InitSizes(size_t old_size,size_t new_size)91   virtual void InitSizes(size_t old_size, size_t new_size) {}
DeclareTypeMismatch(int iold,int inew)92   virtual void DeclareTypeMismatch(int iold, int inew) {}
DeclareUnsafeDistance(int iold,int inew)93   virtual void DeclareUnsafeDistance(int iold, int inew) {}
DeclareCandidate(int iold,int inew)94   virtual void DeclareCandidate(int iold, int inew) {}
DeclareMatch(int iold,int inew,double dist,bool is_identical)95   virtual void DeclareMatch(int iold,
96                             int inew,
97                             double dist,
98                             bool is_identical) {}
DeclareOutlier(int iold,int inew)99   virtual void DeclareOutlier(int iold, int inew) {}
100 
OutputCompare(const Element & old_element,const Element & new_element,double dist)101   virtual void OutputCompare(const Element& old_element,
102                              const Element& new_element,
103                              double dist) {}
104 
OutputMatch(const Element & best_old_element,const Element & new_element,bool is_identical,double best_dist)105   virtual void OutputMatch(const Element& best_old_element,
106                            const Element& new_element,
107                            bool is_identical,
108                            double best_dist) {}
109 
OutputScores(const std::string & stats)110   virtual void OutputScores(const std::string& stats) {}
111 
OutputTextGrid()112   virtual void OutputTextGrid() {}
113 };
114 
/******** MatchingInfoOutTerse ********/
116 
117 // A terse MatchingInfoOut that prints only basic information, using LOG().
118 class MatchingInfoOutTerse : public MatchingInfoOut {
119  public:
120   MatchingInfoOutTerse() = default;
121   MatchingInfoOutTerse(const MatchingInfoOutTerse&) = delete;
122   const MatchingInfoOutTerse& operator=(const MatchingInfoOutTerse&) = delete;
123   ~MatchingInfoOutTerse() override = default;
124 
OutputScores(const std::string & stats)125   void OutputScores(const std::string& stats) override {
126     LOG(INFO) << "Best dists: " << stats;
127   }
128 };
129 
130 /******** MatchingInfoOutVerbose ********/
131 
132 // A verbose MatchingInfoOut that prints detailed information using |out_|,
133 // including comparison pairs, scores, and a text grid representation of
134 // pairwise matching results.
135 class MatchingInfoOutVerbose : public MatchingInfoOut {
136  public:
MatchingInfoOutVerbose(std::ostream & out)137   explicit MatchingInfoOutVerbose(std::ostream& out) : out_(out) {}
138   MatchingInfoOutVerbose(const MatchingInfoOutVerbose&) = delete;
139   const MatchingInfoOutVerbose& operator=(const MatchingInfoOutVerbose&) =
140       delete;
141   ~MatchingInfoOutVerbose() override = default;
142 
143   // Outputs sizes and initializes |text_grid_|.
InitSizes(size_t old_size,size_t new_size)144   void InitSizes(size_t old_size, size_t new_size) override {
145     out_ << "Comparing old (" << old_size << " elements) and new (" << new_size
146          << " elements)" << std::endl;
147     text_grid_.assign(new_size, std::string(old_size, '-'));
148     best_dist_.assign(new_size, -1.0);
149   }
150 
151   // Functions to update match status in text grid representation.
152 
DeclareTypeMismatch(int iold,int inew)153   void DeclareTypeMismatch(int iold, int inew) override {
154     text_grid_[inew][iold] = 'T';
155   }
DeclareUnsafeDistance(int iold,int inew)156   void DeclareUnsafeDistance(int iold, int inew) override {
157     text_grid_[inew][iold] = 'U';
158   }
DeclareCandidate(int iold,int inew)159   void DeclareCandidate(int iold, int inew) override {
160     text_grid_[inew][iold] = 'C';  // Provisional.
161   }
DeclareMatch(int iold,int inew,double dist,bool is_identical)162   void DeclareMatch(int iold,
163                     int inew,
164                     double dist,
165                     bool is_identical) override {
166     text_grid_[inew][iold] = is_identical ? 'I' : 'M';
167     best_dist_[inew] = dist;
168   }
DeclareOutlier(int iold,int inew)169   void DeclareOutlier(int iold, int inew) override {
170     text_grid_[inew][iold] = 'O';
171   }
172 
173   // Functions to print detailed information.
174 
OutputCompare(const Element & old_element,const Element & new_element,double dist)175   void OutputCompare(const Element& old_element,
176                      const Element& new_element,
177                      double dist) override {
178     out_ << "Compare old" << old_element << " to new" << new_element << " --> "
179          << base::StringPrintf("%.5f", dist) << std::endl;
180   }
181 
OutputMatch(const Element & best_old_element,const Element & new_element,bool is_identical,double best_dist)182   void OutputMatch(const Element& best_old_element,
183                    const Element& new_element,
184                    bool is_identical,
185                    double best_dist) override {
186     if (is_identical) {
187       out_ << "Skipped old" << best_old_element << " - identical to new"
188            << new_element;
189     } else {
190       out_ << "Matched old" << best_old_element << " to new" << new_element
191            << " --> " << base::StringPrintf("%.5f", best_dist);
192     }
193     out_ << std::endl;
194   }
195 
OutputScores(const std::string & stats)196   void OutputScores(const std::string& stats) override {
197     out_ << "Best dists: " << stats << std::endl;
198   }
199 
OutputTextGrid()200   void OutputTextGrid() override {
201     int new_size = static_cast<int>(text_grid_.size());
202     for (int inew = 0; inew < new_size; ++inew) {
203       const std::string& line = text_grid_[inew];
204       out_ << "  ";
205       for (char ch : line) {
206         char prefix = (ch == 'I' || ch == 'M') ? '(' : ' ';
207         char suffix = (ch == 'I' || ch == 'M') ? ')' : ' ';
208         out_ << prefix << ch << suffix;
209       }
210       if (best_dist_[inew] >= 0)
211         out_ << "   " << base::StringPrintf("%.5f", best_dist_[inew]);
212       out_ << std::endl;
213     }
214     if (!text_grid_.empty()) {
215       out_ << "  Legend: I = identical, M = matched, T = type mismatch, "
216               "U = unsafe distance, C = candidate, O = outlier, - = skipped."
217            << std::endl;
218     }
219   }
220 
221  private:
222   std::ostream& out_;
223 
224   // Text grid representation of matches. Rows correspond to "old" and columns
225   // correspond to "new".
226   std::vector<std::string> text_grid_;
227 
228   // For each "new" element, distance of best match. -1 denotes no match.
229   std::vector<double> best_dist_;
230 };
231 
232 }  // namespace
233 
234 /******** HeuristicEnsembleMatcher ********/
235 
HeuristicEnsembleMatcher(std::ostream * out)236 HeuristicEnsembleMatcher::HeuristicEnsembleMatcher(std::ostream* out)
237     : out_(out) {}
238 
239 HeuristicEnsembleMatcher::~HeuristicEnsembleMatcher() = default;
240 
RunMatch(ConstBufferView old_image,ConstBufferView new_image)241 bool HeuristicEnsembleMatcher::RunMatch(ConstBufferView old_image,
242                                         ConstBufferView new_image) {
243   DCHECK(matches_.empty());
244   LOG(INFO) << "Start matching.";
245 
246   // Find all elements in "old" and "new".
247   std::optional<std::vector<Element>> old_elements =
248       FindEmbeddedElements(old_image, "Old file",
249                            base::BindRepeating(DetectElementFromDisassembler));
250   if (!old_elements.has_value())
251     return false;
252   std::optional<std::vector<Element>> new_elements =
253       FindEmbeddedElements(new_image, "New file",
254                            base::BindRepeating(DetectElementFromDisassembler));
255   if (!new_elements.has_value())
256     return false;
257 
258   std::unique_ptr<MatchingInfoOut> info_out;
259   if (out_)
260     info_out = std::make_unique<MatchingInfoOutVerbose>(*out_);
261   else
262     info_out = std::make_unique<MatchingInfoOutTerse>();
263 
264   const int num_new_elements = base::checked_cast<int>(new_elements->size());
265   const int num_old_elements = base::checked_cast<int>(old_elements->size());
266   info_out->InitSizes(num_old_elements, num_new_elements);
267 
268   // For each "new" element, match it with the "old" element that's nearest to
269   // it, with distance determined by BinaryDataHistogram. The resulting
270   // "old"-"new" pairs are stored into |results|. Possibilities:
271   // - Type mismatch: No match.
272   // - UnsafeDifference() heuristics fail: No match.
273   // - Identical match: Skip "new" since this is a trivial case.
274   // - Non-identical match: Match "new" with "old" with min distance.
275   // - No match: Skip "new".
276   struct Results {
277     int iold;
278     int inew;
279     double dist;
280   };
281   std::vector<Results> results;
282 
283   // Precompute histograms for "old" since they get reused.
284   std::vector<BinaryDataHistogram> old_his(num_old_elements);
285   for (int iold = 0; iold < num_old_elements; ++iold) {
286     ConstBufferView sub_image(old_image[(*old_elements)[iold]]);
287     old_his[iold].Compute(sub_image);
288     // ProgramDetector should have imposed minimal size limit to |sub_image|.
289     // Therefore resulting histogram are expected to be valid.
290     CHECK(old_his[iold].IsValid());
291   }
292 
293   const int kUninitIold = num_old_elements;
294   for (int inew = 0; inew < num_new_elements; ++inew) {
295     const Element& cur_new_element = (*new_elements)[inew];
296     ConstBufferView cur_new_sub_image(new_image[cur_new_element.region()]);
297     BinaryDataHistogram new_his;
298     new_his.Compute(cur_new_sub_image);
299     CHECK(new_his.IsValid());
300 
301     double best_dist = HUGE_VAL;
302     int best_iold = kUninitIold;
303     bool is_identical = false;
304 
305     for (int iold = 0; iold < num_old_elements; ++iold) {
306       const Element& cur_old_element = (*old_elements)[iold];
307       if (cur_old_element.exe_type != cur_new_element.exe_type) {
308         info_out->DeclareTypeMismatch(iold, inew);
309         continue;
310       }
311       if (UnsafeDifference(cur_old_element, cur_new_element)) {
312         info_out->DeclareUnsafeDistance(iold, inew);
313         continue;
314       }
315       double dist = old_his[iold].Distance(new_his);
316       info_out->DeclareCandidate(iold, inew);
317       info_out->OutputCompare(cur_old_element, cur_new_element, dist);
318       if (best_dist > dist) {  // Tie resolution: First-one, first-serve.
319         best_iold = iold;
320         best_dist = dist;
321         if (best_dist == 0) {
322           ConstBufferView sub_image(old_image[cur_old_element.region()]);
323           if (sub_image.equals(cur_new_sub_image)) {
324             is_identical = true;
325             break;
326           }
327         }
328       }
329     }
330 
331     if (best_iold != kUninitIold) {
332       const Element& best_old_element = (*old_elements)[best_iold];
333       info_out->DeclareMatch(best_iold, inew, best_dist, is_identical);
334       if (is_identical)  // Skip "new" if identical match is found.
335         ++num_identical_;
336       else
337         results.push_back({best_iold, inew, best_dist});
338       info_out->OutputMatch(best_old_element, cur_new_element, is_identical,
339                             best_dist);
340     }
341   }
342 
343   // Populate |matches_| from |result|. To reduce that chance of false-positive
344   // matches, statistics on dists are computed. If a match's |dist| is an
345   // outlier then it is rejected.
346   if (results.size() > 0) {
347     OutlierDetector detector;
348     for (const auto& result : results) {
349       if (result.dist > 0)
350         detector.Add(result.dist);
351     }
352     detector.Prepare();
353     info_out->OutputScores(detector.RenderStats());
354     for (const Results& result : results) {
355       if (detector.DecideOutlier(result.dist) > 0) {
356         info_out->DeclareOutlier(result.iold, result.inew);
357       } else {
358         matches_.push_back(
359             {(*old_elements)[result.iold], (*new_elements)[result.inew]});
360       }
361     }
362     info_out->OutputTextGrid();
363   }
364 
365   Trim();
366   return true;
367 }
368 
369 }  // namespace zucchini
370