xref: /aosp_15_r20/external/pytorch/c10/util/ApproximateClock.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ApproximateClock.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/util/ArrayRef.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <fmt/format.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker namespace c10 {
7*da0073e9SAndroid Build Coastguard Worker 
ApproximateClockToUnixTimeConverter()8*da0073e9SAndroid Build Coastguard Worker ApproximateClockToUnixTimeConverter::ApproximateClockToUnixTimeConverter()
9*da0073e9SAndroid Build Coastguard Worker     : start_times_(measurePairs()) {}
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker ApproximateClockToUnixTimeConverter::UnixAndApproximateTimePair
measurePair()12*da0073e9SAndroid Build Coastguard Worker ApproximateClockToUnixTimeConverter::measurePair() {
13*da0073e9SAndroid Build Coastguard Worker   // Take a measurement on either side to avoid an ordering bias.
14*da0073e9SAndroid Build Coastguard Worker   auto fast_0 = getApproximateTime();
15*da0073e9SAndroid Build Coastguard Worker   auto wall = std::chrono::system_clock::now();
16*da0073e9SAndroid Build Coastguard Worker   auto fast_1 = getApproximateTime();
17*da0073e9SAndroid Build Coastguard Worker 
18*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(fast_1 >= fast_0, "getCount is non-monotonic.");
19*da0073e9SAndroid Build Coastguard Worker   auto t = std::chrono::duration_cast<std::chrono::nanoseconds>(
20*da0073e9SAndroid Build Coastguard Worker       wall.time_since_epoch());
21*da0073e9SAndroid Build Coastguard Worker 
22*da0073e9SAndroid Build Coastguard Worker   // `x + (y - x) / 2` is a more numerically stable average than `(x + y) / 2`.
23*da0073e9SAndroid Build Coastguard Worker   return {t.count(), fast_0 + (fast_1 - fast_0) / 2};
24*da0073e9SAndroid Build Coastguard Worker }
25*da0073e9SAndroid Build Coastguard Worker 
26*da0073e9SAndroid Build Coastguard Worker ApproximateClockToUnixTimeConverter::time_pairs
measurePairs()27*da0073e9SAndroid Build Coastguard Worker ApproximateClockToUnixTimeConverter::measurePairs() {
28*da0073e9SAndroid Build Coastguard Worker   static constexpr auto n_warmup = 5;
29*da0073e9SAndroid Build Coastguard Worker   for (C10_UNUSED const auto _ : c10::irange(n_warmup)) {
30*da0073e9SAndroid Build Coastguard Worker     getApproximateTime();
31*da0073e9SAndroid Build Coastguard Worker     static_cast<void>(steady_clock_t::now());
32*da0073e9SAndroid Build Coastguard Worker   }
33*da0073e9SAndroid Build Coastguard Worker 
34*da0073e9SAndroid Build Coastguard Worker   time_pairs out;
35*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(out.size())) {
36*da0073e9SAndroid Build Coastguard Worker     out[i] = measurePair();
37*da0073e9SAndroid Build Coastguard Worker   }
38*da0073e9SAndroid Build Coastguard Worker   return out;
39*da0073e9SAndroid Build Coastguard Worker }
40*da0073e9SAndroid Build Coastguard Worker 
41*da0073e9SAndroid Build Coastguard Worker std::function<time_t(approx_time_t)> ApproximateClockToUnixTimeConverter::
makeConverter()42*da0073e9SAndroid Build Coastguard Worker     makeConverter() {
43*da0073e9SAndroid Build Coastguard Worker   auto end_times = measurePairs();
44*da0073e9SAndroid Build Coastguard Worker 
45*da0073e9SAndroid Build Coastguard Worker   // Compute the real time that passes for each tick of the approximate clock.
46*da0073e9SAndroid Build Coastguard Worker   std::array<long double, replicates> scale_factors{};
47*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(replicates)) {
48*da0073e9SAndroid Build Coastguard Worker     auto delta_ns = end_times[i].t_ - start_times_[i].t_;
49*da0073e9SAndroid Build Coastguard Worker     auto delta_approx = end_times[i].approx_t_ - start_times_[i].approx_t_;
50*da0073e9SAndroid Build Coastguard Worker     scale_factors[i] = (double)delta_ns / (double)delta_approx;
51*da0073e9SAndroid Build Coastguard Worker   }
52*da0073e9SAndroid Build Coastguard Worker   std::sort(scale_factors.begin(), scale_factors.end());
53*da0073e9SAndroid Build Coastguard Worker   long double scale_factor = scale_factors[replicates / 2 + 1];
54*da0073e9SAndroid Build Coastguard Worker 
55*da0073e9SAndroid Build Coastguard Worker   // We shift all times by `t0` for better numerics. Double precision only has
56*da0073e9SAndroid Build Coastguard Worker   // 16 decimal digits of accuracy, so if we blindly multiply times by
57*da0073e9SAndroid Build Coastguard Worker   // `scale_factor` we may suffer from precision loss. The choice of `t0` is
58*da0073e9SAndroid Build Coastguard Worker   // mostly arbitrary; we just need a factor that is the correct order of
59*da0073e9SAndroid Build Coastguard Worker   // magnitude to bring the intermediate values closer to zero. We are not,
60*da0073e9SAndroid Build Coastguard Worker   // however, guaranteed that `t0_approx` is *exactly* the getApproximateTime
61*da0073e9SAndroid Build Coastguard Worker   // equivalent of `t0`; it is only an estimate that we have to fine tune.
62*da0073e9SAndroid Build Coastguard Worker   auto t0 = start_times_[0].t_;
63*da0073e9SAndroid Build Coastguard Worker   auto t0_approx = start_times_[0].approx_t_;
64*da0073e9SAndroid Build Coastguard Worker   std::array<double, replicates> t0_correction{};
65*da0073e9SAndroid Build Coastguard Worker   for (const auto i : c10::irange(replicates)) {
66*da0073e9SAndroid Build Coastguard Worker     auto dt = start_times_[i].t_ - t0;
67*da0073e9SAndroid Build Coastguard Worker     auto dt_approx =
68*da0073e9SAndroid Build Coastguard Worker         (double)(start_times_[i].approx_t_ - t0_approx) * scale_factor;
69*da0073e9SAndroid Build Coastguard Worker     t0_correction[i] = dt - (time_t)dt_approx; // NOLINT
70*da0073e9SAndroid Build Coastguard Worker   }
71*da0073e9SAndroid Build Coastguard Worker   t0 += t0_correction[t0_correction.size() / 2 + 1]; // NOLINT
72*da0073e9SAndroid Build Coastguard Worker 
73*da0073e9SAndroid Build Coastguard Worker   return [=](approx_time_t t_approx) {
74*da0073e9SAndroid Build Coastguard Worker     // See above for why this is more stable than `A * t_approx + B`.
75*da0073e9SAndroid Build Coastguard Worker     return (time_t)((double)(t_approx - t0_approx) * scale_factor) + t0;
76*da0073e9SAndroid Build Coastguard Worker   };
77*da0073e9SAndroid Build Coastguard Worker }
78*da0073e9SAndroid Build Coastguard Worker 
79*da0073e9SAndroid Build Coastguard Worker } // namespace c10
80