1 // Copyright 2019 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #ifndef UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_H_
6 #define UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_H_
7
#include <bitset>
#include <cstdint>
#include <map>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
16
17 #include "base/time/time.h"
18 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
19 #include "ui/events/ozone/evdev/event_device_info.h"
20 #endif
21 #include "ui/events/ozone/evdev/touch_evdev_types.h"
22 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_model.h"
23 #include "ui/events/ozone/evdev/touch_filter/neural_stylus_palm_detection_filter_util.h"
24 #include "ui/events/ozone/evdev/touch_filter/palm_detection_filter.h"
25 #include "ui/events/ozone/evdev/touch_filter/shared_palm_detection_filter_state.h"
26 #include "ui/gfx/geometry/point_f.h"
27
28 namespace ui {
29
#if defined(__ANDROID__) || defined(__ANDROID_HOST__)
// Android-only resampling interval of 8 ms. Presumably the step used when
// strokes are resampled (see AppendResampledFeatures) -- confirm in the .cc.
// `constexpr` guarantees constant initialization, so including this header can
// never add a dynamic static initializer to a translation unit; like the
// previous `const`, a namespace-scope `constexpr` keeps internal linkage.
constexpr base::TimeDelta kResamplePeriod = base::Milliseconds(8);
#endif
33
// Writes every entry of `map` to `out` in key order, one "key : value" pair per
// line. Both K and V must themselves be streamable. Returns `out` for chaining.
template <typename K, typename V>
std::ostream& operator<<(std::ostream& out, const std::map<K, V>& map) {
  for (auto it = map.begin(); it != map.end(); ++it) {
    const K& key = it->first;
    const V& value = it->second;
    out << key << " : " << value << "\n";
  }
  return out;
}
41
// Writes the elements of `set` to `out` inside braces, each element followed by
// ", " (including the last one), e.g. "{3, 1, }". Element order follows the
// unordered_set's iteration order and is therefore unspecified. Returns `out`.
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::unordered_set<T>& set) {
  out << "{";
  for (auto it = set.cbegin(); it != set.cend(); ++it) {
    out << *it << ", ";
  }
  return out << "}";
}
51
52 // An implementation of PalmDetectionFilter that relies on a DNN implementation
53 // to decide on palm detection. Requires a configured model as an argument.
54 // Heuristics are added for handling short strokes
COMPONENT_EXPORT(EVDEV)55 class COMPONENT_EXPORT(EVDEV) NeuralStylusPalmDetectionFilter
56 : public PalmDetectionFilter {
57 public:
58 // Takes ownership of the model.
59 NeuralStylusPalmDetectionFilter(
60 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
61 const EventDeviceInfo& devinfo,
62 #else
63 PalmFilterDeviceInfo palm_filter_device_info,
64 #endif
65 std::unique_ptr<NeuralStylusPalmDetectionFilterModel> palm_model,
66 SharedPalmDetectionFilterState* shared_palm_state);
67
68 NeuralStylusPalmDetectionFilter(const NeuralStylusPalmDetectionFilter&) =
69 delete;
70 NeuralStylusPalmDetectionFilter& operator=(
71 const NeuralStylusPalmDetectionFilter&) = delete;
72
73 ~NeuralStylusPalmDetectionFilter() override;
74 void Filter(const std::vector<InProgressTouchEvdev>& touches,
75 base::TimeTicks time,
76 std::bitset<kNumTouchEvdevSlots>* slots_to_hold,
77 std::bitset<kNumTouchEvdevSlots>* slots_to_suppress) override;
78 #if !defined(__ANDROID__) && !defined(__ANDROID_HOST__)
79 static bool CompatibleWithNeuralStylusPalmDetectionFilter(
80 const EventDeviceInfo& devinfo);
81
82 static bool CompatibleWithNeuralStylusPalmDetectionFilter(
83 const EventDeviceInfo& devinfo,
84 const std::string& ozone_params_switch_string);
85 #endif
86 static const int kFeaturesPerSample;
87 static const int kExtraFeaturesForNeighbor;
88
89 static const char kFilterName[];
90 std::string FilterNameForTesting() const override;
91
92 private:
93 void FindNearestNeighborsWithin(
94 int neighbor_count,
95 unsigned long neighbor_min_sample_count,
96 float max_distance,
97 const PalmFilterStroke& stroke,
98 std::vector<std::pair<float, int>>* nearest_strokes) const;
99 void FindBiggestNeighborsWithin(
100 int neighbor_count,
101 unsigned long neighbor_min_sample_count,
102 float max_distance,
103 const PalmFilterStroke& stroke,
104 std::vector<std::pair<float, int>>* biggest_strokes) const;
105
106 bool DetectSpuriousStroke(const std::vector<float>& features,
107 float threshold) const;
108 // Extracts the feature vector for the specified stroke.
109 std::vector<float> ExtractFeatures(int tracking_id) const;
110 void AppendFeatures(const PalmFilterStroke& stroke,
111 std::vector<float>* features) const;
112 void AppendResampledFeatures(const PalmFilterStroke& stroke,
113 std::vector<float>* features) const;
114 void AppendFeaturesAsNeighbor(const PalmFilterStroke& stroke,
115 float distance,
116 std::vector<float>* features) const;
117
118 bool ShouldDecideStroke(const PalmFilterStroke& stroke) const;
119 bool IsHeuristicPalmStroke(const PalmFilterStroke& stroke) const;
120 void EraseOldStrokes(base::TimeTicks time);
121
122 std::bitset<kNumTouchEvdevSlots> is_palm_;
123 std::bitset<kNumTouchEvdevSlots> is_delay_;
124 std::map<int, PalmFilterStroke> strokes_;
125 base::TimeTicks previous_report_time_;
126 std::unordered_set<int> active_tracking_ids_;
127 int tracking_ids_count_within_session_;
128 int tracking_ids_[kNumTouchEvdevSlots];
129 const PalmFilterDeviceInfo palm_filter_dev_info_;
130 std::unique_ptr<NeuralStylusPalmDetectionFilterModel> model_;
131
132 friend std::ostream& operator<<(
133 std::ostream& out,
134 const NeuralStylusPalmDetectionFilter& filter);
135 };
136
137 std::ostream& operator<<(std::ostream& out,
138 const NeuralStylusPalmDetectionFilter& filter);
139
140 } // namespace ui
141
142 #endif // UI_EVENTS_OZONE_EVDEV_TOUCH_FILTER_NEURAL_STYLUS_PALM_DETECTION_FILTER_H_
143