xref: /aosp_15_r20/frameworks/native/include/input/TfLiteMotionPredictor.h (revision 38e8c45f13ce32b0dcecb25141ffecaf386fa17f)
1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <array>
20 #include <cstddef>
21 #include <cstdint>
22 #include <memory>
23 #include <optional>
24 #include <span>
25 
26 #include <android-base/mapped_file.h>
27 #include <input/RingBuffer.h>
28 #include <utils/Timers.h>
29 
30 #include <tensorflow/lite/core/api/error_reporter.h>
31 #include <tensorflow/lite/interpreter.h>
32 #include <tensorflow/lite/model.h>
33 #include <tensorflow/lite/signature_runner.h>
34 
35 namespace android {
36 
// One motion sample, expressed in the raw axis values the predictor consumes.
struct TfLiteMotionPredictorSample {
    // A two-dimensional coordinate pair.
    struct Point {
        float x;
        float y;
    };

    // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
    Point position;
    // The AMOTION_EVENT_AXIS_PRESSURE of the sample.
    float pressure;
    // The AMOTION_EVENT_AXIS_TILT of the sample.
    float tilt;
    // The AMOTION_EVENT_AXIS_ORIENTATION of the sample.
    float orientation;
};
48 
49 inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
50                                                     const TfLiteMotionPredictorSample::Point& rhs) {
51     return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
52 }
53 
54 class TfLiteMotionPredictorModel;
55 
56 // Buffer storage for a TfLiteMotionPredictorModel.
57 class TfLiteMotionPredictorBuffers {
58 public:
59     // Creates buffer storage for a model with the given input length.
60     TfLiteMotionPredictorBuffers(size_t inputLength);
61 
62     // Adds a motion sample to the buffers.
63     void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);
64 
65     // Returns true if the buffers are complete enough to generate a prediction.
isReady()66     bool isReady() const {
67         // Predictions can't be applied unless there are at least two points to determine
68         // the direction to apply them in.
69         return mAxisFrom && mAxisTo;
70     }
71 
72     // Resets all buffers to their initial state.
73     void reset();
74 
75     // Copies the buffers to those of a model for prediction.
76     void copyTo(TfLiteMotionPredictorModel& model) const;
77 
78     // Returns the current axis of the buffer's samples. Only valid if isReady().
axisFrom()79     TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
axisTo()80     TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }
81 
82     // Returns the timestamp of the last sample.
lastTimestamp()83     int64_t lastTimestamp() const { return mTimestamp; }
84 
85 private:
86     int64_t mTimestamp = 0;
87 
88     RingBuffer<float> mInputR;
89     RingBuffer<float> mInputPhi;
90     RingBuffer<float> mInputPressure;
91     RingBuffer<float> mInputTilt;
92     RingBuffer<float> mInputOrientation;
93 
94     // The samples defining the current polar axis.
95     std::optional<TfLiteMotionPredictorSample> mAxisFrom;
96     std::optional<TfLiteMotionPredictorSample> mAxisTo;
97 };
98 
99 // A TFLite model for generating motion predictions.
100 class TfLiteMotionPredictorModel {
101 public:
102     struct Config {
103         // The time between predictions.
104         nsecs_t predictionInterval = 0;
105         // The noise floor for predictions.
106         // Distances (r) less than this should be discarded as noise.
107         float distanceNoiseFloor = 0;
108 
109         // Low and high jerk thresholds (with normalized dt = 1) for predictions.
110         // High jerk means more predictions will be pruned, vice versa for low.
111         float lowJerk = 0;
112         float highJerk = 0;
113 
114         // Coefficient for the first-order IIR filter for jerk calculation.
115         float jerkAlpha = 1;
116     };
117 
118     // Creates a model from an encoded Flatbuffer model.
119     static std::unique_ptr<TfLiteMotionPredictorModel> create();
120 
121     ~TfLiteMotionPredictorModel();
122 
123     // Returns the length of the model's input buffers.
124     size_t inputLength() const;
125 
126     // Returns the length of the model's output buffers.
127     size_t outputLength() const;
128 
config()129     const Config& config() const { return mConfig; }
130 
131     // Executes the model.
132     // Returns true if the model successfully executed and the output tensors can be read.
133     bool invoke();
134 
135     // Returns mutable buffers to the input tensors of inputLength() elements.
136     std::span<float> inputR();
137     std::span<float> inputPhi();
138     std::span<float> inputPressure();
139     std::span<float> inputOrientation();
140     std::span<float> inputTilt();
141 
142     // Returns immutable buffers to the output tensors of identical length. Only valid after a
143     // successful call to invoke().
144     std::span<const float> outputR() const;
145     std::span<const float> outputPhi() const;
146     std::span<const float> outputPressure() const;
147 
148 private:
149     explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
150                                         Config config);
151 
152     void allocateTensors();
153     void attachInputTensors();
154     void attachOutputTensors();
155 
156     TfLiteTensor* mInputR = nullptr;
157     TfLiteTensor* mInputPhi = nullptr;
158     TfLiteTensor* mInputPressure = nullptr;
159     TfLiteTensor* mInputTilt = nullptr;
160     TfLiteTensor* mInputOrientation = nullptr;
161 
162     const TfLiteTensor* mOutputR = nullptr;
163     const TfLiteTensor* mOutputPhi = nullptr;
164     const TfLiteTensor* mOutputPressure = nullptr;
165 
166     std::unique_ptr<android::base::MappedFile> mFlatBuffer;
167     std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
168     std::unique_ptr<tflite::FlatBufferModel> mModel;
169     std::unique_ptr<tflite::Interpreter> mInterpreter;
170     tflite::SignatureRunner* mRunner = nullptr;
171 
172     const Config mConfig = {};
173 };
174 
175 } // namespace android
176