xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/distribution_utilities.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "fcp/secagg/server/distribution_utilities.h"
18 
19 #include <cmath>
20 #include <iostream>
21 #include <memory>
22 
23 namespace fcp {
24 namespace secagg {
25 
26 StatusOr<std::unique_ptr<HypergeometricDistribution>>
Create(int total,int marked,int sampled)27 HypergeometricDistribution::Create(int total, int marked, int sampled) {
28   if (total < 0) {
29     return FCP_STATUS(FAILED_PRECONDITION)
30            << "The population should be at least zero. Value provided = "
31            << total;
32   }
33   if (marked < 0) {
34     return FCP_STATUS(FAILED_PRECONDITION)
35            << "The marked population should have size at least zero. Value "
36               "provided = "
37            << marked;
38   }
39   if (sampled < 0) {
40     return FCP_STATUS(FAILED_PRECONDITION)
41            << "The sample size should be at least zero. Value provided = "
42            << sampled;
43   }
44   if (marked > total) {
45     return FCP_STATUS(FAILED_PRECONDITION)
46            << "The marked population " << marked
47            << " should not exceed the total population " << total;
48   }
49   if (sampled > total) {
50     return FCP_STATUS(FAILED_PRECONDITION)
51            << "The sample size " << sampled
52            << " should not exceed the total population " << total;
53   }
54   return std::unique_ptr<HypergeometricDistribution>(
55       new HypergeometricDistribution(total, marked, sampled));
56 }
57 
PMF(double x)58 double HypergeometricDistribution::PMF(double x) { return PMFImpl(x, marked_); }
59 
PMFImpl(double x,int counted)60 double HypergeometricDistribution::PMFImpl(double x, int counted) {
61   if (x < 0 || x > sampled_ || x > counted) return 0;
62   if (total_ + x < counted + sampled_) return 0;
63   double lpmf = std::lgamma(sampled_ + 1) + std::lgamma(counted + 1) +
64                 std::lgamma(total_ - counted + 1) +
65                 std::lgamma(total_ - sampled_ + 1) - std::lgamma(x + 1) -
66                 std::lgamma(sampled_ - x + 1) - std::lgamma(counted - x + 1) -
67                 std::lgamma(total_ + 1) -
68                 std::lgamma(total_ - sampled_ - counted + x + 1);
69   return std::exp(lpmf);
70 }
71 
CDF(double x)72 double HypergeometricDistribution::CDF(double x) {
73   x = std::floor(x);
74   double mean = marked_ * static_cast<double>(sampled_) / total_;
75   if (x > mean) {
76     return 1 - CDFImpl(sampled_ - x - 1, total_ - marked_);
77   } else {
78     return CDFImpl(x, marked_);
79   }
80 }
81 
CDFImpl(double x,int counted)82 double HypergeometricDistribution::CDFImpl(double x, int counted) {
83   double current_pmf = PMFImpl(x, counted);
84   double result = 0;
85   while (current_pmf > result * 1e-16) {
86     result += current_pmf;
87     current_pmf *= x;
88     current_pmf *= total_ - counted - sampled_ + x;
89     current_pmf /= counted - x + 1;
90     current_pmf /= sampled_ - x + 1;
91     --x;
92   }
93   return result;
94 }
95 
FindQuantile(double quantile,bool complement)96 double HypergeometricDistribution::FindQuantile(double quantile,
97                                                 bool complement) {
98   if (quantile > 0.5) {
99     quantile = 1 - quantile;
100     complement = !complement;
101   }
102   if (complement) {
103     return sampled_ - FindQuantileImpl(quantile, total_ - marked_) - 1;
104   } else {
105     return FindQuantileImpl(quantile, marked_);
106   }
107 }
108 
FindQuantileImpl(double quantile,int counted)109 double HypergeometricDistribution::FindQuantileImpl(double quantile,
110                                                     int counted) {
111   double basic_bound = counted + sampled_ - total_ - 1;
112   // An inverted tail bound gives a lower bound on the result
113   double fancy_bound =
114       sampled_ * (static_cast<double>(counted) / total_ -
115                   std::sqrt(-std::log(quantile) / (2 * sampled_)));
116   double result = -1;
117   if (basic_bound > result) result = basic_bound;
118   if (fancy_bound > result) result = fancy_bound;
119   result = std::floor(result);
120 
121   double current_cdf = CDFImpl(result, counted);
122   double current_pmf = PMFImpl(result, counted);
123   while (current_cdf < quantile && result < sampled_) {
124     if (current_pmf > 0) {
125       current_pmf /= result + 1;
126       current_pmf /= total_ - counted - sampled_ + result + 1;
127       current_pmf *= counted - result;
128       current_pmf *= sampled_ - result;
129     } else {
130       current_pmf = PMFImpl(result + 1, counted);
131     }
132     current_cdf += current_pmf;
133     ++result;
134   }
135   --result;
136   return result;
137 }
138 
139 }  // namespace secagg
140 }  // namespace fcp
141