xref: /aosp_15_r20/frameworks/native/include/input/TfLiteMotionPredictor.h (revision 38e8c45f13ce32b0dcecb25141ffecaf386fa17f)
1*38e8c45fSAndroid Build Coastguard Worker /*
2*38e8c45fSAndroid Build Coastguard Worker  * Copyright (C) 2023 The Android Open Source Project
3*38e8c45fSAndroid Build Coastguard Worker  *
4*38e8c45fSAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*38e8c45fSAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*38e8c45fSAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*38e8c45fSAndroid Build Coastguard Worker  *
8*38e8c45fSAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*38e8c45fSAndroid Build Coastguard Worker  *
10*38e8c45fSAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*38e8c45fSAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*38e8c45fSAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*38e8c45fSAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*38e8c45fSAndroid Build Coastguard Worker  * limitations under the License.
15*38e8c45fSAndroid Build Coastguard Worker  */
16*38e8c45fSAndroid Build Coastguard Worker 
17*38e8c45fSAndroid Build Coastguard Worker #pragma once
18*38e8c45fSAndroid Build Coastguard Worker 
19*38e8c45fSAndroid Build Coastguard Worker #include <array>
20*38e8c45fSAndroid Build Coastguard Worker #include <cstddef>
21*38e8c45fSAndroid Build Coastguard Worker #include <cstdint>
22*38e8c45fSAndroid Build Coastguard Worker #include <memory>
23*38e8c45fSAndroid Build Coastguard Worker #include <optional>
24*38e8c45fSAndroid Build Coastguard Worker #include <span>
25*38e8c45fSAndroid Build Coastguard Worker 
26*38e8c45fSAndroid Build Coastguard Worker #include <android-base/mapped_file.h>
27*38e8c45fSAndroid Build Coastguard Worker #include <input/RingBuffer.h>
28*38e8c45fSAndroid Build Coastguard Worker #include <utils/Timers.h>
29*38e8c45fSAndroid Build Coastguard Worker 
30*38e8c45fSAndroid Build Coastguard Worker #include <tensorflow/lite/core/api/error_reporter.h>
31*38e8c45fSAndroid Build Coastguard Worker #include <tensorflow/lite/interpreter.h>
32*38e8c45fSAndroid Build Coastguard Worker #include <tensorflow/lite/model.h>
33*38e8c45fSAndroid Build Coastguard Worker #include <tensorflow/lite/signature_runner.h>
34*38e8c45fSAndroid Build Coastguard Worker 
35*38e8c45fSAndroid Build Coastguard Worker namespace android {
36*38e8c45fSAndroid Build Coastguard Worker 
// A single motion sample as consumed by TfLiteMotionPredictorBuffers::pushSample().
//
// Every field has a default member initializer of zero, so a default-constructed sample
// (e.g. `TfLiteMotionPredictorSample sample;`) never carries indeterminate values. The type
// remains an aggregate, so brace / designated initialization keeps working unchanged.
struct TfLiteMotionPredictorSample {
    // The untransformed AMOTION_EVENT_AXIS_X and AMOTION_EVENT_AXIS_Y of the sample.
    struct Point {
        float x = 0;
        float y = 0;
    } position;
    // The AMOTION_EVENT_AXIS_PRESSURE, _TILT, and _ORIENTATION.
    float pressure = 0;
    float tilt = 0;
    float orientation = 0;
};
48*38e8c45fSAndroid Build Coastguard Worker 
49*38e8c45fSAndroid Build Coastguard Worker inline TfLiteMotionPredictorSample::Point operator-(const TfLiteMotionPredictorSample::Point& lhs,
50*38e8c45fSAndroid Build Coastguard Worker                                                     const TfLiteMotionPredictorSample::Point& rhs) {
51*38e8c45fSAndroid Build Coastguard Worker     return {.x = lhs.x - rhs.x, .y = lhs.y - rhs.y};
52*38e8c45fSAndroid Build Coastguard Worker }
53*38e8c45fSAndroid Build Coastguard Worker 
54*38e8c45fSAndroid Build Coastguard Worker class TfLiteMotionPredictorModel;
55*38e8c45fSAndroid Build Coastguard Worker 
56*38e8c45fSAndroid Build Coastguard Worker // Buffer storage for a TfLiteMotionPredictorModel.
57*38e8c45fSAndroid Build Coastguard Worker class TfLiteMotionPredictorBuffers {
58*38e8c45fSAndroid Build Coastguard Worker public:
59*38e8c45fSAndroid Build Coastguard Worker     // Creates buffer storage for a model with the given input length.
60*38e8c45fSAndroid Build Coastguard Worker     TfLiteMotionPredictorBuffers(size_t inputLength);
61*38e8c45fSAndroid Build Coastguard Worker 
62*38e8c45fSAndroid Build Coastguard Worker     // Adds a motion sample to the buffers.
63*38e8c45fSAndroid Build Coastguard Worker     void pushSample(int64_t timestamp, TfLiteMotionPredictorSample sample);
64*38e8c45fSAndroid Build Coastguard Worker 
65*38e8c45fSAndroid Build Coastguard Worker     // Returns true if the buffers are complete enough to generate a prediction.
isReady()66*38e8c45fSAndroid Build Coastguard Worker     bool isReady() const {
67*38e8c45fSAndroid Build Coastguard Worker         // Predictions can't be applied unless there are at least two points to determine
68*38e8c45fSAndroid Build Coastguard Worker         // the direction to apply them in.
69*38e8c45fSAndroid Build Coastguard Worker         return mAxisFrom && mAxisTo;
70*38e8c45fSAndroid Build Coastguard Worker     }
71*38e8c45fSAndroid Build Coastguard Worker 
72*38e8c45fSAndroid Build Coastguard Worker     // Resets all buffers to their initial state.
73*38e8c45fSAndroid Build Coastguard Worker     void reset();
74*38e8c45fSAndroid Build Coastguard Worker 
75*38e8c45fSAndroid Build Coastguard Worker     // Copies the buffers to those of a model for prediction.
76*38e8c45fSAndroid Build Coastguard Worker     void copyTo(TfLiteMotionPredictorModel& model) const;
77*38e8c45fSAndroid Build Coastguard Worker 
78*38e8c45fSAndroid Build Coastguard Worker     // Returns the current axis of the buffer's samples. Only valid if isReady().
axisFrom()79*38e8c45fSAndroid Build Coastguard Worker     TfLiteMotionPredictorSample axisFrom() const { return *mAxisFrom; }
axisTo()80*38e8c45fSAndroid Build Coastguard Worker     TfLiteMotionPredictorSample axisTo() const { return *mAxisTo; }
81*38e8c45fSAndroid Build Coastguard Worker 
82*38e8c45fSAndroid Build Coastguard Worker     // Returns the timestamp of the last sample.
lastTimestamp()83*38e8c45fSAndroid Build Coastguard Worker     int64_t lastTimestamp() const { return mTimestamp; }
84*38e8c45fSAndroid Build Coastguard Worker 
85*38e8c45fSAndroid Build Coastguard Worker private:
86*38e8c45fSAndroid Build Coastguard Worker     int64_t mTimestamp = 0;
87*38e8c45fSAndroid Build Coastguard Worker 
88*38e8c45fSAndroid Build Coastguard Worker     RingBuffer<float> mInputR;
89*38e8c45fSAndroid Build Coastguard Worker     RingBuffer<float> mInputPhi;
90*38e8c45fSAndroid Build Coastguard Worker     RingBuffer<float> mInputPressure;
91*38e8c45fSAndroid Build Coastguard Worker     RingBuffer<float> mInputTilt;
92*38e8c45fSAndroid Build Coastguard Worker     RingBuffer<float> mInputOrientation;
93*38e8c45fSAndroid Build Coastguard Worker 
94*38e8c45fSAndroid Build Coastguard Worker     // The samples defining the current polar axis.
95*38e8c45fSAndroid Build Coastguard Worker     std::optional<TfLiteMotionPredictorSample> mAxisFrom;
96*38e8c45fSAndroid Build Coastguard Worker     std::optional<TfLiteMotionPredictorSample> mAxisTo;
97*38e8c45fSAndroid Build Coastguard Worker };
98*38e8c45fSAndroid Build Coastguard Worker 
99*38e8c45fSAndroid Build Coastguard Worker // A TFLite model for generating motion predictions.
100*38e8c45fSAndroid Build Coastguard Worker class TfLiteMotionPredictorModel {
101*38e8c45fSAndroid Build Coastguard Worker public:
102*38e8c45fSAndroid Build Coastguard Worker     struct Config {
103*38e8c45fSAndroid Build Coastguard Worker         // The time between predictions.
104*38e8c45fSAndroid Build Coastguard Worker         nsecs_t predictionInterval = 0;
105*38e8c45fSAndroid Build Coastguard Worker         // The noise floor for predictions.
106*38e8c45fSAndroid Build Coastguard Worker         // Distances (r) less than this should be discarded as noise.
107*38e8c45fSAndroid Build Coastguard Worker         float distanceNoiseFloor = 0;
108*38e8c45fSAndroid Build Coastguard Worker 
109*38e8c45fSAndroid Build Coastguard Worker         // Low and high jerk thresholds (with normalized dt = 1) for predictions.
110*38e8c45fSAndroid Build Coastguard Worker         // High jerk means more predictions will be pruned, vice versa for low.
111*38e8c45fSAndroid Build Coastguard Worker         float lowJerk = 0;
112*38e8c45fSAndroid Build Coastguard Worker         float highJerk = 0;
113*38e8c45fSAndroid Build Coastguard Worker 
114*38e8c45fSAndroid Build Coastguard Worker         // Coefficient for the first-order IIR filter for jerk calculation.
115*38e8c45fSAndroid Build Coastguard Worker         float jerkAlpha = 1;
116*38e8c45fSAndroid Build Coastguard Worker     };
117*38e8c45fSAndroid Build Coastguard Worker 
118*38e8c45fSAndroid Build Coastguard Worker     // Creates a model from an encoded Flatbuffer model.
119*38e8c45fSAndroid Build Coastguard Worker     static std::unique_ptr<TfLiteMotionPredictorModel> create();
120*38e8c45fSAndroid Build Coastguard Worker 
121*38e8c45fSAndroid Build Coastguard Worker     ~TfLiteMotionPredictorModel();
122*38e8c45fSAndroid Build Coastguard Worker 
123*38e8c45fSAndroid Build Coastguard Worker     // Returns the length of the model's input buffers.
124*38e8c45fSAndroid Build Coastguard Worker     size_t inputLength() const;
125*38e8c45fSAndroid Build Coastguard Worker 
126*38e8c45fSAndroid Build Coastguard Worker     // Returns the length of the model's output buffers.
127*38e8c45fSAndroid Build Coastguard Worker     size_t outputLength() const;
128*38e8c45fSAndroid Build Coastguard Worker 
config()129*38e8c45fSAndroid Build Coastguard Worker     const Config& config() const { return mConfig; }
130*38e8c45fSAndroid Build Coastguard Worker 
131*38e8c45fSAndroid Build Coastguard Worker     // Executes the model.
132*38e8c45fSAndroid Build Coastguard Worker     // Returns true if the model successfully executed and the output tensors can be read.
133*38e8c45fSAndroid Build Coastguard Worker     bool invoke();
134*38e8c45fSAndroid Build Coastguard Worker 
135*38e8c45fSAndroid Build Coastguard Worker     // Returns mutable buffers to the input tensors of inputLength() elements.
136*38e8c45fSAndroid Build Coastguard Worker     std::span<float> inputR();
137*38e8c45fSAndroid Build Coastguard Worker     std::span<float> inputPhi();
138*38e8c45fSAndroid Build Coastguard Worker     std::span<float> inputPressure();
139*38e8c45fSAndroid Build Coastguard Worker     std::span<float> inputOrientation();
140*38e8c45fSAndroid Build Coastguard Worker     std::span<float> inputTilt();
141*38e8c45fSAndroid Build Coastguard Worker 
142*38e8c45fSAndroid Build Coastguard Worker     // Returns immutable buffers to the output tensors of identical length. Only valid after a
143*38e8c45fSAndroid Build Coastguard Worker     // successful call to invoke().
144*38e8c45fSAndroid Build Coastguard Worker     std::span<const float> outputR() const;
145*38e8c45fSAndroid Build Coastguard Worker     std::span<const float> outputPhi() const;
146*38e8c45fSAndroid Build Coastguard Worker     std::span<const float> outputPressure() const;
147*38e8c45fSAndroid Build Coastguard Worker 
148*38e8c45fSAndroid Build Coastguard Worker private:
149*38e8c45fSAndroid Build Coastguard Worker     explicit TfLiteMotionPredictorModel(std::unique_ptr<android::base::MappedFile> model,
150*38e8c45fSAndroid Build Coastguard Worker                                         Config config);
151*38e8c45fSAndroid Build Coastguard Worker 
152*38e8c45fSAndroid Build Coastguard Worker     void allocateTensors();
153*38e8c45fSAndroid Build Coastguard Worker     void attachInputTensors();
154*38e8c45fSAndroid Build Coastguard Worker     void attachOutputTensors();
155*38e8c45fSAndroid Build Coastguard Worker 
156*38e8c45fSAndroid Build Coastguard Worker     TfLiteTensor* mInputR = nullptr;
157*38e8c45fSAndroid Build Coastguard Worker     TfLiteTensor* mInputPhi = nullptr;
158*38e8c45fSAndroid Build Coastguard Worker     TfLiteTensor* mInputPressure = nullptr;
159*38e8c45fSAndroid Build Coastguard Worker     TfLiteTensor* mInputTilt = nullptr;
160*38e8c45fSAndroid Build Coastguard Worker     TfLiteTensor* mInputOrientation = nullptr;
161*38e8c45fSAndroid Build Coastguard Worker 
162*38e8c45fSAndroid Build Coastguard Worker     const TfLiteTensor* mOutputR = nullptr;
163*38e8c45fSAndroid Build Coastguard Worker     const TfLiteTensor* mOutputPhi = nullptr;
164*38e8c45fSAndroid Build Coastguard Worker     const TfLiteTensor* mOutputPressure = nullptr;
165*38e8c45fSAndroid Build Coastguard Worker 
166*38e8c45fSAndroid Build Coastguard Worker     std::unique_ptr<android::base::MappedFile> mFlatBuffer;
167*38e8c45fSAndroid Build Coastguard Worker     std::unique_ptr<tflite::ErrorReporter> mErrorReporter;
168*38e8c45fSAndroid Build Coastguard Worker     std::unique_ptr<tflite::FlatBufferModel> mModel;
169*38e8c45fSAndroid Build Coastguard Worker     std::unique_ptr<tflite::Interpreter> mInterpreter;
170*38e8c45fSAndroid Build Coastguard Worker     tflite::SignatureRunner* mRunner = nullptr;
171*38e8c45fSAndroid Build Coastguard Worker 
172*38e8c45fSAndroid Build Coastguard Worker     const Config mConfig = {};
173*38e8c45fSAndroid Build Coastguard Worker };
174*38e8c45fSAndroid Build Coastguard Worker 
175*38e8c45fSAndroid Build Coastguard Worker } // namespace android
176