xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/distribution_utilities.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2023 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/server/distribution_utilities.h"
18*14675a02SAndroid Build Coastguard Worker 
19*14675a02SAndroid Build Coastguard Worker #include <cmath>
20*14675a02SAndroid Build Coastguard Worker #include <iostream>
21*14675a02SAndroid Build Coastguard Worker #include <memory>
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker namespace fcp {
24*14675a02SAndroid Build Coastguard Worker namespace secagg {
25*14675a02SAndroid Build Coastguard Worker 
26*14675a02SAndroid Build Coastguard Worker StatusOr<std::unique_ptr<HypergeometricDistribution>>
Create(int total,int marked,int sampled)27*14675a02SAndroid Build Coastguard Worker HypergeometricDistribution::Create(int total, int marked, int sampled) {
28*14675a02SAndroid Build Coastguard Worker   if (total < 0) {
29*14675a02SAndroid Build Coastguard Worker     return FCP_STATUS(FAILED_PRECONDITION)
30*14675a02SAndroid Build Coastguard Worker            << "The population should be at least zero. Value provided = "
31*14675a02SAndroid Build Coastguard Worker            << total;
32*14675a02SAndroid Build Coastguard Worker   }
33*14675a02SAndroid Build Coastguard Worker   if (marked < 0) {
34*14675a02SAndroid Build Coastguard Worker     return FCP_STATUS(FAILED_PRECONDITION)
35*14675a02SAndroid Build Coastguard Worker            << "The marked population should have size at least zero. Value "
36*14675a02SAndroid Build Coastguard Worker               "provided = "
37*14675a02SAndroid Build Coastguard Worker            << marked;
38*14675a02SAndroid Build Coastguard Worker   }
39*14675a02SAndroid Build Coastguard Worker   if (sampled < 0) {
40*14675a02SAndroid Build Coastguard Worker     return FCP_STATUS(FAILED_PRECONDITION)
41*14675a02SAndroid Build Coastguard Worker            << "The sample size should be at least zero. Value provided = "
42*14675a02SAndroid Build Coastguard Worker            << sampled;
43*14675a02SAndroid Build Coastguard Worker   }
44*14675a02SAndroid Build Coastguard Worker   if (marked > total) {
45*14675a02SAndroid Build Coastguard Worker     return FCP_STATUS(FAILED_PRECONDITION)
46*14675a02SAndroid Build Coastguard Worker            << "The marked population " << marked
47*14675a02SAndroid Build Coastguard Worker            << " should not exceed the total population " << total;
48*14675a02SAndroid Build Coastguard Worker   }
49*14675a02SAndroid Build Coastguard Worker   if (sampled > total) {
50*14675a02SAndroid Build Coastguard Worker     return FCP_STATUS(FAILED_PRECONDITION)
51*14675a02SAndroid Build Coastguard Worker            << "The sample size " << sampled
52*14675a02SAndroid Build Coastguard Worker            << " should not exceed the total population " << total;
53*14675a02SAndroid Build Coastguard Worker   }
54*14675a02SAndroid Build Coastguard Worker   return std::unique_ptr<HypergeometricDistribution>(
55*14675a02SAndroid Build Coastguard Worker       new HypergeometricDistribution(total, marked, sampled));
56*14675a02SAndroid Build Coastguard Worker }
57*14675a02SAndroid Build Coastguard Worker 
PMF(double x)58*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::PMF(double x) { return PMFImpl(x, marked_); }
59*14675a02SAndroid Build Coastguard Worker 
PMFImpl(double x,int counted)60*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::PMFImpl(double x, int counted) {
61*14675a02SAndroid Build Coastguard Worker   if (x < 0 || x > sampled_ || x > counted) return 0;
62*14675a02SAndroid Build Coastguard Worker   if (total_ + x < counted + sampled_) return 0;
63*14675a02SAndroid Build Coastguard Worker   double lpmf = std::lgamma(sampled_ + 1) + std::lgamma(counted + 1) +
64*14675a02SAndroid Build Coastguard Worker                 std::lgamma(total_ - counted + 1) +
65*14675a02SAndroid Build Coastguard Worker                 std::lgamma(total_ - sampled_ + 1) - std::lgamma(x + 1) -
66*14675a02SAndroid Build Coastguard Worker                 std::lgamma(sampled_ - x + 1) - std::lgamma(counted - x + 1) -
67*14675a02SAndroid Build Coastguard Worker                 std::lgamma(total_ + 1) -
68*14675a02SAndroid Build Coastguard Worker                 std::lgamma(total_ - sampled_ - counted + x + 1);
69*14675a02SAndroid Build Coastguard Worker   return std::exp(lpmf);
70*14675a02SAndroid Build Coastguard Worker }
71*14675a02SAndroid Build Coastguard Worker 
CDF(double x)72*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::CDF(double x) {
73*14675a02SAndroid Build Coastguard Worker   x = std::floor(x);
74*14675a02SAndroid Build Coastguard Worker   double mean = marked_ * static_cast<double>(sampled_) / total_;
75*14675a02SAndroid Build Coastguard Worker   if (x > mean) {
76*14675a02SAndroid Build Coastguard Worker     return 1 - CDFImpl(sampled_ - x - 1, total_ - marked_);
77*14675a02SAndroid Build Coastguard Worker   } else {
78*14675a02SAndroid Build Coastguard Worker     return CDFImpl(x, marked_);
79*14675a02SAndroid Build Coastguard Worker   }
80*14675a02SAndroid Build Coastguard Worker }
81*14675a02SAndroid Build Coastguard Worker 
CDFImpl(double x,int counted)82*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::CDFImpl(double x, int counted) {
83*14675a02SAndroid Build Coastguard Worker   double current_pmf = PMFImpl(x, counted);
84*14675a02SAndroid Build Coastguard Worker   double result = 0;
85*14675a02SAndroid Build Coastguard Worker   while (current_pmf > result * 1e-16) {
86*14675a02SAndroid Build Coastguard Worker     result += current_pmf;
87*14675a02SAndroid Build Coastguard Worker     current_pmf *= x;
88*14675a02SAndroid Build Coastguard Worker     current_pmf *= total_ - counted - sampled_ + x;
89*14675a02SAndroid Build Coastguard Worker     current_pmf /= counted - x + 1;
90*14675a02SAndroid Build Coastguard Worker     current_pmf /= sampled_ - x + 1;
91*14675a02SAndroid Build Coastguard Worker     --x;
92*14675a02SAndroid Build Coastguard Worker   }
93*14675a02SAndroid Build Coastguard Worker   return result;
94*14675a02SAndroid Build Coastguard Worker }
95*14675a02SAndroid Build Coastguard Worker 
FindQuantile(double quantile,bool complement)96*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::FindQuantile(double quantile,
97*14675a02SAndroid Build Coastguard Worker                                                 bool complement) {
98*14675a02SAndroid Build Coastguard Worker   if (quantile > 0.5) {
99*14675a02SAndroid Build Coastguard Worker     quantile = 1 - quantile;
100*14675a02SAndroid Build Coastguard Worker     complement = !complement;
101*14675a02SAndroid Build Coastguard Worker   }
102*14675a02SAndroid Build Coastguard Worker   if (complement) {
103*14675a02SAndroid Build Coastguard Worker     return sampled_ - FindQuantileImpl(quantile, total_ - marked_) - 1;
104*14675a02SAndroid Build Coastguard Worker   } else {
105*14675a02SAndroid Build Coastguard Worker     return FindQuantileImpl(quantile, marked_);
106*14675a02SAndroid Build Coastguard Worker   }
107*14675a02SAndroid Build Coastguard Worker }
108*14675a02SAndroid Build Coastguard Worker 
FindQuantileImpl(double quantile,int counted)109*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::FindQuantileImpl(double quantile,
110*14675a02SAndroid Build Coastguard Worker                                                     int counted) {
111*14675a02SAndroid Build Coastguard Worker   double basic_bound = counted + sampled_ - total_ - 1;
112*14675a02SAndroid Build Coastguard Worker   // An inverted tail bound gives a lower bound on the result
113*14675a02SAndroid Build Coastguard Worker   double fancy_bound =
114*14675a02SAndroid Build Coastguard Worker       sampled_ * (static_cast<double>(counted) / total_ -
115*14675a02SAndroid Build Coastguard Worker                   std::sqrt(-std::log(quantile) / (2 * sampled_)));
116*14675a02SAndroid Build Coastguard Worker   double result = -1;
117*14675a02SAndroid Build Coastguard Worker   if (basic_bound > result) result = basic_bound;
118*14675a02SAndroid Build Coastguard Worker   if (fancy_bound > result) result = fancy_bound;
119*14675a02SAndroid Build Coastguard Worker   result = std::floor(result);
120*14675a02SAndroid Build Coastguard Worker 
121*14675a02SAndroid Build Coastguard Worker   double current_cdf = CDFImpl(result, counted);
122*14675a02SAndroid Build Coastguard Worker   double current_pmf = PMFImpl(result, counted);
123*14675a02SAndroid Build Coastguard Worker   while (current_cdf < quantile && result < sampled_) {
124*14675a02SAndroid Build Coastguard Worker     if (current_pmf > 0) {
125*14675a02SAndroid Build Coastguard Worker       current_pmf /= result + 1;
126*14675a02SAndroid Build Coastguard Worker       current_pmf /= total_ - counted - sampled_ + result + 1;
127*14675a02SAndroid Build Coastguard Worker       current_pmf *= counted - result;
128*14675a02SAndroid Build Coastguard Worker       current_pmf *= sampled_ - result;
129*14675a02SAndroid Build Coastguard Worker     } else {
130*14675a02SAndroid Build Coastguard Worker       current_pmf = PMFImpl(result + 1, counted);
131*14675a02SAndroid Build Coastguard Worker     }
132*14675a02SAndroid Build Coastguard Worker     current_cdf += current_pmf;
133*14675a02SAndroid Build Coastguard Worker     ++result;
134*14675a02SAndroid Build Coastguard Worker   }
135*14675a02SAndroid Build Coastguard Worker   --result;
136*14675a02SAndroid Build Coastguard Worker   return result;
137*14675a02SAndroid Build Coastguard Worker }
138*14675a02SAndroid Build Coastguard Worker 
139*14675a02SAndroid Build Coastguard Worker }  // namespace secagg
140*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
141