xref: /aosp_15_r20/frameworks/av/media/libheadtracking/PosePredictor.h (revision ec779b8e0859a360c3d303172224686826e6e0e1)
1*ec779b8eSAndroid Build Coastguard Worker /*
2*ec779b8eSAndroid Build Coastguard Worker  * Copyright (C) 2023 The Android Open Source Project
3*ec779b8eSAndroid Build Coastguard Worker  *
4*ec779b8eSAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*ec779b8eSAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*ec779b8eSAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*ec779b8eSAndroid Build Coastguard Worker  *
8*ec779b8eSAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*ec779b8eSAndroid Build Coastguard Worker  *
10*ec779b8eSAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*ec779b8eSAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*ec779b8eSAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*ec779b8eSAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*ec779b8eSAndroid Build Coastguard Worker  * limitations under the License.
15*ec779b8eSAndroid Build Coastguard Worker  */
16*ec779b8eSAndroid Build Coastguard Worker 
17*ec779b8eSAndroid Build Coastguard Worker #pragma once
18*ec779b8eSAndroid Build Coastguard Worker 
#include "PosePredictorVerifier.h"

#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

#include <audio_utils/Statistics.h>
#include <media/PosePredictorType.h>
#include <media/Twist.h>
#include <media/VectorRecorder.h>
25*ec779b8eSAndroid Build Coastguard Worker 
26*ec779b8eSAndroid Build Coastguard Worker namespace android::media {
27*ec779b8eSAndroid Build Coastguard Worker 
28*ec779b8eSAndroid Build Coastguard Worker // Interface for generic pose predictors
29*ec779b8eSAndroid Build Coastguard Worker class PredictorBase {
30*ec779b8eSAndroid Build Coastguard Worker public:
31*ec779b8eSAndroid Build Coastguard Worker     virtual ~PredictorBase() = default;
32*ec779b8eSAndroid Build Coastguard Worker     virtual void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) = 0;
33*ec779b8eSAndroid Build Coastguard Worker     virtual Pose3f predict(int64_t atNs) const = 0;
34*ec779b8eSAndroid Build Coastguard Worker     virtual void reset() = 0;
35*ec779b8eSAndroid Build Coastguard Worker     virtual std::string name() const = 0;
36*ec779b8eSAndroid Build Coastguard Worker     virtual std::string toString(size_t index) const = 0;
37*ec779b8eSAndroid Build Coastguard Worker };
38*ec779b8eSAndroid Build Coastguard Worker 
39*ec779b8eSAndroid Build Coastguard Worker /**
40*ec779b8eSAndroid Build Coastguard Worker  * LastPredictor uses the last sample Pose for prediction
41*ec779b8eSAndroid Build Coastguard Worker  *
42*ec779b8eSAndroid Build Coastguard Worker  * This class is not thread-safe.
43*ec779b8eSAndroid Build Coastguard Worker  */
44*ec779b8eSAndroid Build Coastguard Worker class LastPredictor : public PredictorBase {
45*ec779b8eSAndroid Build Coastguard Worker public:
add(int64_t atNs,const Pose3f & pose,const Twist3f & twist)46*ec779b8eSAndroid Build Coastguard Worker     void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override {
47*ec779b8eSAndroid Build Coastguard Worker         (void)atNs;
48*ec779b8eSAndroid Build Coastguard Worker         (void)twist;
49*ec779b8eSAndroid Build Coastguard Worker         mLastPose = pose;
50*ec779b8eSAndroid Build Coastguard Worker     }
51*ec779b8eSAndroid Build Coastguard Worker 
predict(int64_t atNs)52*ec779b8eSAndroid Build Coastguard Worker     Pose3f predict(int64_t atNs) const override {
53*ec779b8eSAndroid Build Coastguard Worker         (void)atNs;
54*ec779b8eSAndroid Build Coastguard Worker         return mLastPose;
55*ec779b8eSAndroid Build Coastguard Worker     }
56*ec779b8eSAndroid Build Coastguard Worker 
reset()57*ec779b8eSAndroid Build Coastguard Worker     void reset() override {
58*ec779b8eSAndroid Build Coastguard Worker         mLastPose = {};
59*ec779b8eSAndroid Build Coastguard Worker     }
60*ec779b8eSAndroid Build Coastguard Worker 
name()61*ec779b8eSAndroid Build Coastguard Worker     std::string name() const override {
62*ec779b8eSAndroid Build Coastguard Worker         return "LAST";
63*ec779b8eSAndroid Build Coastguard Worker     }
64*ec779b8eSAndroid Build Coastguard Worker 
toString(size_t index)65*ec779b8eSAndroid Build Coastguard Worker     std::string toString(size_t index) const override {
66*ec779b8eSAndroid Build Coastguard Worker         std::string s(index, ' ');
67*ec779b8eSAndroid Build Coastguard Worker         s.append("LastPredictor using last pose: ")
68*ec779b8eSAndroid Build Coastguard Worker             .append(mLastPose.toString())
69*ec779b8eSAndroid Build Coastguard Worker             .append("\n");
70*ec779b8eSAndroid Build Coastguard Worker         return s;
71*ec779b8eSAndroid Build Coastguard Worker     }
72*ec779b8eSAndroid Build Coastguard Worker 
73*ec779b8eSAndroid Build Coastguard Worker private:
74*ec779b8eSAndroid Build Coastguard Worker     Pose3f mLastPose;
75*ec779b8eSAndroid Build Coastguard Worker };
76*ec779b8eSAndroid Build Coastguard Worker 
77*ec779b8eSAndroid Build Coastguard Worker /**
78*ec779b8eSAndroid Build Coastguard Worker  * TwistPredictor uses the last sample Twist and Pose for prediction
79*ec779b8eSAndroid Build Coastguard Worker  *
80*ec779b8eSAndroid Build Coastguard Worker  * This class is not thread-safe.
81*ec779b8eSAndroid Build Coastguard Worker  */
82*ec779b8eSAndroid Build Coastguard Worker class TwistPredictor : public PredictorBase {
83*ec779b8eSAndroid Build Coastguard Worker public:
add(int64_t atNs,const Pose3f & pose,const Twist3f & twist)84*ec779b8eSAndroid Build Coastguard Worker     void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override {
85*ec779b8eSAndroid Build Coastguard Worker         mLastAtNs = atNs;
86*ec779b8eSAndroid Build Coastguard Worker         mLastPose = pose;
87*ec779b8eSAndroid Build Coastguard Worker         mLastTwist = twist;
88*ec779b8eSAndroid Build Coastguard Worker     }
89*ec779b8eSAndroid Build Coastguard Worker 
predict(int64_t atNs)90*ec779b8eSAndroid Build Coastguard Worker     Pose3f predict(int64_t atNs) const override {
91*ec779b8eSAndroid Build Coastguard Worker         return mLastPose * integrate(mLastTwist, atNs - mLastAtNs);
92*ec779b8eSAndroid Build Coastguard Worker     }
93*ec779b8eSAndroid Build Coastguard Worker 
reset()94*ec779b8eSAndroid Build Coastguard Worker     void reset() override {
95*ec779b8eSAndroid Build Coastguard Worker         mLastAtNs = {};
96*ec779b8eSAndroid Build Coastguard Worker         mLastPose = {};
97*ec779b8eSAndroid Build Coastguard Worker         mLastTwist = {};
98*ec779b8eSAndroid Build Coastguard Worker     }
99*ec779b8eSAndroid Build Coastguard Worker 
name()100*ec779b8eSAndroid Build Coastguard Worker     std::string name() const override {
101*ec779b8eSAndroid Build Coastguard Worker         return "TWIST";
102*ec779b8eSAndroid Build Coastguard Worker     }
103*ec779b8eSAndroid Build Coastguard Worker 
toString(size_t index)104*ec779b8eSAndroid Build Coastguard Worker     std::string toString(size_t index) const override {
105*ec779b8eSAndroid Build Coastguard Worker         std::string s(index, ' ');
106*ec779b8eSAndroid Build Coastguard Worker         s.append("TwistPredictor using last pose: ")
107*ec779b8eSAndroid Build Coastguard Worker             .append(mLastPose.toString())
108*ec779b8eSAndroid Build Coastguard Worker             .append(" last twist: ")
109*ec779b8eSAndroid Build Coastguard Worker             .append(mLastTwist.toString())
110*ec779b8eSAndroid Build Coastguard Worker             .append("\n");
111*ec779b8eSAndroid Build Coastguard Worker         return s;
112*ec779b8eSAndroid Build Coastguard Worker     }
113*ec779b8eSAndroid Build Coastguard Worker 
114*ec779b8eSAndroid Build Coastguard Worker private:
115*ec779b8eSAndroid Build Coastguard Worker     int64_t mLastAtNs{};
116*ec779b8eSAndroid Build Coastguard Worker     Pose3f mLastPose;
117*ec779b8eSAndroid Build Coastguard Worker     Twist3f mLastTwist;
118*ec779b8eSAndroid Build Coastguard Worker };
119*ec779b8eSAndroid Build Coastguard Worker 
120*ec779b8eSAndroid Build Coastguard Worker 
121*ec779b8eSAndroid Build Coastguard Worker /**
122*ec779b8eSAndroid Build Coastguard Worker  * LeastSquaresPredictor uses the Pose history for prediction.
123*ec779b8eSAndroid Build Coastguard Worker  *
124*ec779b8eSAndroid Build Coastguard Worker  * A exponential weighted least squares is used.
125*ec779b8eSAndroid Build Coastguard Worker  *
126*ec779b8eSAndroid Build Coastguard Worker  * This class is not thread-safe.
127*ec779b8eSAndroid Build Coastguard Worker  */
128*ec779b8eSAndroid Build Coastguard Worker class LeastSquaresPredictor : public PredictorBase {
129*ec779b8eSAndroid Build Coastguard Worker public:
130*ec779b8eSAndroid Build Coastguard Worker     // alpha is the exponential decay.
131*ec779b8eSAndroid Build Coastguard Worker     LeastSquaresPredictor(double alpha = kDefaultAlphaEstimator)
mAlpha(alpha)132*ec779b8eSAndroid Build Coastguard Worker         : mAlpha(alpha)
133*ec779b8eSAndroid Build Coastguard Worker         , mRw(alpha)
134*ec779b8eSAndroid Build Coastguard Worker         , mRx(alpha)
135*ec779b8eSAndroid Build Coastguard Worker         , mRy(alpha)
136*ec779b8eSAndroid Build Coastguard Worker         , mRz(alpha)
137*ec779b8eSAndroid Build Coastguard Worker         {}
138*ec779b8eSAndroid Build Coastguard Worker 
139*ec779b8eSAndroid Build Coastguard Worker     void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override;
140*ec779b8eSAndroid Build Coastguard Worker     Pose3f predict(int64_t atNs) const override;
141*ec779b8eSAndroid Build Coastguard Worker     void reset() override;
name()142*ec779b8eSAndroid Build Coastguard Worker     std::string name() const override {
143*ec779b8eSAndroid Build Coastguard Worker         return "LEAST_SQUARES(" + std::to_string(mAlpha) + ")";
144*ec779b8eSAndroid Build Coastguard Worker     }
145*ec779b8eSAndroid Build Coastguard Worker     std::string toString(size_t index) const override;
146*ec779b8eSAndroid Build Coastguard Worker 
147*ec779b8eSAndroid Build Coastguard Worker private:
148*ec779b8eSAndroid Build Coastguard Worker     const double mAlpha;
149*ec779b8eSAndroid Build Coastguard Worker     int64_t mLastAtNs{};
150*ec779b8eSAndroid Build Coastguard Worker     Pose3f mLastPose;
151*ec779b8eSAndroid Build Coastguard Worker     static constexpr double kDefaultAlphaEstimator = 0.2;
152*ec779b8eSAndroid Build Coastguard Worker     static constexpr size_t kMinimumSamplesForPrediction = 4;
153*ec779b8eSAndroid Build Coastguard Worker     audio_utils::LinearLeastSquaresFit<double> mRw;
154*ec779b8eSAndroid Build Coastguard Worker     audio_utils::LinearLeastSquaresFit<double> mRx;
155*ec779b8eSAndroid Build Coastguard Worker     audio_utils::LinearLeastSquaresFit<double> mRy;
156*ec779b8eSAndroid Build Coastguard Worker     audio_utils::LinearLeastSquaresFit<double> mRz;
157*ec779b8eSAndroid Build Coastguard Worker };
158*ec779b8eSAndroid Build Coastguard Worker 
159*ec779b8eSAndroid Build Coastguard Worker /*
160*ec779b8eSAndroid Build Coastguard Worker  * PosePredictor predicts the pose given sensor input at a time in the future.
161*ec779b8eSAndroid Build Coastguard Worker  *
162*ec779b8eSAndroid Build Coastguard Worker  * This class is not thread safe.
163*ec779b8eSAndroid Build Coastguard Worker  */
164*ec779b8eSAndroid Build Coastguard Worker class PosePredictor {
165*ec779b8eSAndroid Build Coastguard Worker public:
166*ec779b8eSAndroid Build Coastguard Worker     PosePredictor();
167*ec779b8eSAndroid Build Coastguard Worker 
168*ec779b8eSAndroid Build Coastguard Worker     Pose3f predict(int64_t timestampNs, const Pose3f& pose, const Twist3f& twist,
169*ec779b8eSAndroid Build Coastguard Worker             float predictionDurationNs);
170*ec779b8eSAndroid Build Coastguard Worker 
171*ec779b8eSAndroid Build Coastguard Worker     void setPosePredictorType(PosePredictorType type);
172*ec779b8eSAndroid Build Coastguard Worker 
173*ec779b8eSAndroid Build Coastguard Worker     // convert predictions to a printable string
174*ec779b8eSAndroid Build Coastguard Worker     std::string toString(size_t index) const;
175*ec779b8eSAndroid Build Coastguard Worker 
176*ec779b8eSAndroid Build Coastguard Worker private:
177*ec779b8eSAndroid Build Coastguard Worker     static constexpr int64_t kMaximumSampleIntervalBeforeResetNs =
178*ec779b8eSAndroid Build Coastguard Worker             300'000'000;
179*ec779b8eSAndroid Build Coastguard Worker 
180*ec779b8eSAndroid Build Coastguard Worker     // Predictors
181*ec779b8eSAndroid Build Coastguard Worker     const std::vector<std::shared_ptr<PredictorBase>> mPredictors;
182*ec779b8eSAndroid Build Coastguard Worker 
183*ec779b8eSAndroid Build Coastguard Worker     // Verifiers, create one for an array of future lookaheads for comparison.
184*ec779b8eSAndroid Build Coastguard Worker     const std::vector<int> mLookaheadMs;
185*ec779b8eSAndroid Build Coastguard Worker 
186*ec779b8eSAndroid Build Coastguard Worker     std::vector<PosePredictorVerifier> mVerifiers;
187*ec779b8eSAndroid Build Coastguard Worker 
188*ec779b8eSAndroid Build Coastguard Worker     const std::vector<size_t> mDelimiterIdx;
189*ec779b8eSAndroid Build Coastguard Worker 
190*ec779b8eSAndroid Build Coastguard Worker     // Recorders
191*ec779b8eSAndroid Build Coastguard Worker     media::VectorRecorder mPredictionRecorder{
192*ec779b8eSAndroid Build Coastguard Worker         std::size(mVerifiers) /* vectorSize */, std::chrono::seconds(1), 10 /* maxLogLine */,
193*ec779b8eSAndroid Build Coastguard Worker         mDelimiterIdx};
194*ec779b8eSAndroid Build Coastguard Worker     media::VectorRecorder mPredictionDurableRecorder{
195*ec779b8eSAndroid Build Coastguard Worker         std::size(mVerifiers) /* vectorSize */, std::chrono::minutes(1), 10 /* maxLogLine */,
196*ec779b8eSAndroid Build Coastguard Worker         mDelimiterIdx};
197*ec779b8eSAndroid Build Coastguard Worker 
198*ec779b8eSAndroid Build Coastguard Worker     // Status
199*ec779b8eSAndroid Build Coastguard Worker 
200*ec779b8eSAndroid Build Coastguard Worker     // SetType is the externally set predictor type.  It may include AUTO.
201*ec779b8eSAndroid Build Coastguard Worker     PosePredictorType mSetType = PosePredictorType::LEAST_SQUARES;
202*ec779b8eSAndroid Build Coastguard Worker 
203*ec779b8eSAndroid Build Coastguard Worker     // CurrentType is the actual predictor type used by this class.
204*ec779b8eSAndroid Build Coastguard Worker     // It does not include AUTO because that metatype means the class
205*ec779b8eSAndroid Build Coastguard Worker     // chooses the best predictor type based on sensor statistics.
206*ec779b8eSAndroid Build Coastguard Worker     PosePredictorType mCurrentType = PosePredictorType::LEAST_SQUARES;
207*ec779b8eSAndroid Build Coastguard Worker 
208*ec779b8eSAndroid Build Coastguard Worker     int64_t mResets{};
209*ec779b8eSAndroid Build Coastguard Worker     int64_t mLastTimestampNs{};
210*ec779b8eSAndroid Build Coastguard Worker 
211*ec779b8eSAndroid Build Coastguard Worker     // Returns current predictor
212*ec779b8eSAndroid Build Coastguard Worker     std::shared_ptr<PredictorBase> getCurrentPredictor() const;
213*ec779b8eSAndroid Build Coastguard Worker };
214*ec779b8eSAndroid Build Coastguard Worker 
215*ec779b8eSAndroid Build Coastguard Worker }  // namespace android::media
216