1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_H_
6 #define UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_H_
7 
#include <bitset>
#include <cstdint>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
16 
17 #include "base/time/time.h"
18 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
19 #include "ui/events/ozone/evdev/event_device_info.h"
20 #endif
21 #include "ui/events/ozone/evdev/touch_evdev_types.h"
22 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h"
23 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h"
24 #include "ui/events/ozone/evdev/touch_filter/palm_detection_filter.h"
25 #include "ui/events/ozone/evdev/touch_filter/shared_palm_detection_filter_state.h"
26 #include "ui/gfx/geometry/point_f.h"
27 
28 namespace ui {
29 
#if defined(__ANDROID__) || defined(__ANDROID_HOST__)
// Fixed 8 ms period, only declared for Android builds (which also do not
// include event_device_info.h). NOTE(review): presumably the interval used when
// resampling strokes for feature extraction; confirm against the .cc file.
const base::TimeDelta kResamplePeriod{base::Milliseconds(8)};
#endif
33 
// Streams every entry of |map| into |out| as one "key : value" line per entry,
// in the map's iteration (comparator) order. Both the key and the mapped type
// must themselves be streamable. An empty map writes nothing.
//
// The comparator and allocator are deduced rather than defaulted so that maps
// with a non-default ordering or allocator can be printed too; plain
// std::map<K, V> arguments still deduce exactly as before.
template <typename K, typename V, typename Compare, typename Alloc>
std::ostream& operator<<(std::ostream& out,
                         const std::map<K, V, Compare, Alloc>& map) {
  for (const auto& [key, value] : map) {
    out << key << " : " << value << "\n";
  }
  return out;
}
41 
// Streams |set| into |out| as "{e1, e2, }": every element (in the set's
// unspecified iteration order) is followed by ", ", including the last one,
// and an empty set prints "{}". The trailing separator is kept so existing
// log output does not change. Elements must be streamable.
//
// Hash, equality and allocator are deduced so that sets using non-default
// ones can be printed as well; std::unordered_set<T> still deduces as before.
template <typename T, typename Hash, typename KeyEqual, typename Alloc>
std::ostream& operator<<(
    std::ostream& out,
    const std::unordered_set<T, Hash, KeyEqual, Alloc>& set) {
  out << "{";
  for (const auto& entry : set) {
    out << entry << ", ";
  }
  out << "}";
  return out;
}
51 
52 // An implementation of PalmDetectionFilter that relies on a DNN implementation
53 // to decide on palm detection. Requires a configured model as an argument.
54 // Heuristics are added for handling short strokes
COMPONENT_EXPORT(EVDEV)55 class COMPONENT_EXPORT(EVDEV) NeuralStylusPalmDetectionFilter
56     : public PalmDetectionFilter {
57  public:
58   // Takes ownership of the model.
59   NeuralStylusPalmDetectionFilter(
60 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
61       const EventDeviceInfo& devinfo,
62 #else
63       PalmFilterDeviceInfo palm_filter_device_info,
64 #endif
65       std::unique_ptr<NeuralStylusPalmDetectionFilterModel> palm_model,
66       SharedPalmDetectionFilterState* shared_palm_state);
67 
68   NeuralStylusPalmDetectionFilter(const NeuralStylusPalmDetectionFilter&) =
69       delete;
70   NeuralStylusPalmDetectionFilter& operator=(
71       const NeuralStylusPalmDetectionFilter&) = delete;
72 
73   ~NeuralStylusPalmDetectionFilter() override;
74   void Filter(const std::vector<InProgressTouchEvdev>& touches,
75               base::TimeTicks time,
76               std::bitset<kNumTouchEvdevSlots>* slots_to_hold,
77               std::bitset<kNumTouchEvdevSlots>* slots_to_suppress) override;
78 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
79   static bool CompatibleWithNeuralStylusPalmDetectionFilter(
80       const EventDeviceInfo& devinfo);
81 
82   static bool CompatibleWithNeuralStylusPalmDetectionFilter(
83       const EventDeviceInfo& devinfo,
84       const std::string& ozone_params_switch_string);
85 #endif
86   static const int kFeaturesPerSample;
87   static const int kExtraFeaturesForNeighbor;
88 
89   static const char kFilterName[];
90   std::string FilterNameForTesting() const override;
91 
92  private:
93   void FindNearestNeighborsWithin(
94       int neighbor_count,
95       unsigned long neighbor_min_sample_count,
96       float max_distance,
97       const PalmFilterStroke& stroke,
98       std::vector<std::pair<float, int>>* nearest_strokes) const;
99   void FindBiggestNeighborsWithin(
100       int neighbor_count,
101       unsigned long neighbor_min_sample_count,
102       float max_distance,
103       const PalmFilterStroke& stroke,
104       std::vector<std::pair<float, int>>* biggest_strokes) const;
105 
106   bool DetectSpuriousStroke(const std::vector<float>& features,
107                             float threshold) const;
108   // Extracts the feature vector for the specified stroke.
109   std::vector<float> ExtractFeatures(int tracking_id) const;
110   void AppendFeatures(const PalmFilterStroke& stroke,
111                       std::vector<float>* features) const;
112   void AppendResampledFeatures(const PalmFilterStroke& stroke,
113                                std::vector<float>* features) const;
114   void AppendFeaturesAsNeighbor(const PalmFilterStroke& stroke,
115                                 float distance,
116                                 std::vector<float>* features) const;
117 
118   bool ShouldDecideStroke(const PalmFilterStroke& stroke) const;
119   bool IsHeuristicPalmStroke(const PalmFilterStroke& stroke) const;
120   void EraseOldStrokes(base::TimeTicks time);
121 
122   std::bitset<kNumTouchEvdevSlots> is_palm_;
123   std::bitset<kNumTouchEvdevSlots> is_delay_;
124   std::map<int, PalmFilterStroke> strokes_;
125   base::TimeTicks previous_report_time_;
126   std::unordered_set<int> active_tracking_ids_;
127   int tracking_ids_count_within_session_;
128   int tracking_ids_[kNumTouchEvdevSlots];
129   const PalmFilterDeviceInfo palm_filter_dev_info_;
130   std::unique_ptr<NeuralStylusPalmDetectionFilterModel> model_;
131 
132   friend std::ostream& operator<<(
133       std::ostream& out,
134       const NeuralStylusPalmDetectionFilter& filter);
135 };
136 
137 std::ostream& operator<<(std::ostream& out,
138                          const NeuralStylusPalmDetectionFilter& filter);
139 
140 }  // namespace ui
141 
142 #endif  // UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_H_
143