1 /*
2 * Copyright 2023 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "fcp/secagg/server/distribution_utilities.h"
18
19 #include <cmath>
20 #include <iostream>
21 #include <memory>
22
23 namespace fcp {
24 namespace secagg {
25
26 StatusOr<std::unique_ptr<HypergeometricDistribution>>
Create(int total,int marked,int sampled)27 HypergeometricDistribution::Create(int total, int marked, int sampled) {
28 if (total < 0) {
29 return FCP_STATUS(FAILED_PRECONDITION)
30 << "The population should be at least zero. Value provided = "
31 << total;
32 }
33 if (marked < 0) {
34 return FCP_STATUS(FAILED_PRECONDITION)
35 << "The marked population should have size at least zero. Value "
36 "provided = "
37 << marked;
38 }
39 if (sampled < 0) {
40 return FCP_STATUS(FAILED_PRECONDITION)
41 << "The sample size should be at least zero. Value provided = "
42 << sampled;
43 }
44 if (marked > total) {
45 return FCP_STATUS(FAILED_PRECONDITION)
46 << "The marked population " << marked
47 << " should not exceed the total population " << total;
48 }
49 if (sampled > total) {
50 return FCP_STATUS(FAILED_PRECONDITION)
51 << "The sample size " << sampled
52 << " should not exceed the total population " << total;
53 }
54 return std::unique_ptr<HypergeometricDistribution>(
55 new HypergeometricDistribution(total, marked, sampled));
56 }
57
PMF(double x)58 double HypergeometricDistribution::PMF(double x) { return PMFImpl(x, marked_); }
59
PMFImpl(double x,int counted)60 double HypergeometricDistribution::PMFImpl(double x, int counted) {
61 if (x < 0 || x > sampled_ || x > counted) return 0;
62 if (total_ + x < counted + sampled_) return 0;
63 double lpmf = std::lgamma(sampled_ + 1) + std::lgamma(counted + 1) +
64 std::lgamma(total_ - counted + 1) +
65 std::lgamma(total_ - sampled_ + 1) - std::lgamma(x + 1) -
66 std::lgamma(sampled_ - x + 1) - std::lgamma(counted - x + 1) -
67 std::lgamma(total_ + 1) -
68 std::lgamma(total_ - sampled_ - counted + x + 1);
69 return std::exp(lpmf);
70 }
71
CDF(double x)72 double HypergeometricDistribution::CDF(double x) {
73 x = std::floor(x);
74 double mean = marked_ * static_cast<double>(sampled_) / total_;
75 if (x > mean) {
76 return 1 - CDFImpl(sampled_ - x - 1, total_ - marked_);
77 } else {
78 return CDFImpl(x, marked_);
79 }
80 }
81
CDFImpl(double x,int counted)82 double HypergeometricDistribution::CDFImpl(double x, int counted) {
83 double current_pmf = PMFImpl(x, counted);
84 double result = 0;
85 while (current_pmf > result * 1e-16) {
86 result += current_pmf;
87 current_pmf *= x;
88 current_pmf *= total_ - counted - sampled_ + x;
89 current_pmf /= counted - x + 1;
90 current_pmf /= sampled_ - x + 1;
91 --x;
92 }
93 return result;
94 }
95
FindQuantile(double quantile,bool complement)96 double HypergeometricDistribution::FindQuantile(double quantile,
97 bool complement) {
98 if (quantile > 0.5) {
99 quantile = 1 - quantile;
100 complement = !complement;
101 }
102 if (complement) {
103 return sampled_ - FindQuantileImpl(quantile, total_ - marked_) - 1;
104 } else {
105 return FindQuantileImpl(quantile, marked_);
106 }
107 }
108
FindQuantileImpl(double quantile,int counted)109 double HypergeometricDistribution::FindQuantileImpl(double quantile,
110 int counted) {
111 double basic_bound = counted + sampled_ - total_ - 1;
112 // An inverted tail bound gives a lower bound on the result
113 double fancy_bound =
114 sampled_ * (static_cast<double>(counted) / total_ -
115 std::sqrt(-std::log(quantile) / (2 * sampled_)));
116 double result = -1;
117 if (basic_bound > result) result = basic_bound;
118 if (fancy_bound > result) result = fancy_bound;
119 result = std::floor(result);
120
121 double current_cdf = CDFImpl(result, counted);
122 double current_pmf = PMFImpl(result, counted);
123 while (current_cdf < quantile && result < sampled_) {
124 if (current_pmf > 0) {
125 current_pmf /= result + 1;
126 current_pmf /= total_ - counted - sampled_ + result + 1;
127 current_pmf *= counted - result;
128 current_pmf *= sampled_ - result;
129 } else {
130 current_pmf = PMFImpl(result + 1, counted);
131 }
132 current_cdf += current_pmf;
133 ++result;
134 }
135 --result;
136 return result;
137 }
138
139 } // namespace secagg
140 } // namespace fcp
141