/*
 * Copyright 2023 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16*14675a02SAndroid Build Coastguard Worker
17*14675a02SAndroid Build Coastguard Worker #include "fcp/secagg/server/distribution_utilities.h"
18*14675a02SAndroid Build Coastguard Worker
19*14675a02SAndroid Build Coastguard Worker #include <cmath>
20*14675a02SAndroid Build Coastguard Worker #include <iostream>
21*14675a02SAndroid Build Coastguard Worker #include <memory>
22*14675a02SAndroid Build Coastguard Worker
23*14675a02SAndroid Build Coastguard Worker namespace fcp {
24*14675a02SAndroid Build Coastguard Worker namespace secagg {
25*14675a02SAndroid Build Coastguard Worker
26*14675a02SAndroid Build Coastguard Worker StatusOr<std::unique_ptr<HypergeometricDistribution>>
Create(int total,int marked,int sampled)27*14675a02SAndroid Build Coastguard Worker HypergeometricDistribution::Create(int total, int marked, int sampled) {
28*14675a02SAndroid Build Coastguard Worker if (total < 0) {
29*14675a02SAndroid Build Coastguard Worker return FCP_STATUS(FAILED_PRECONDITION)
30*14675a02SAndroid Build Coastguard Worker << "The population should be at least zero. Value provided = "
31*14675a02SAndroid Build Coastguard Worker << total;
32*14675a02SAndroid Build Coastguard Worker }
33*14675a02SAndroid Build Coastguard Worker if (marked < 0) {
34*14675a02SAndroid Build Coastguard Worker return FCP_STATUS(FAILED_PRECONDITION)
35*14675a02SAndroid Build Coastguard Worker << "The marked population should have size at least zero. Value "
36*14675a02SAndroid Build Coastguard Worker "provided = "
37*14675a02SAndroid Build Coastguard Worker << marked;
38*14675a02SAndroid Build Coastguard Worker }
39*14675a02SAndroid Build Coastguard Worker if (sampled < 0) {
40*14675a02SAndroid Build Coastguard Worker return FCP_STATUS(FAILED_PRECONDITION)
41*14675a02SAndroid Build Coastguard Worker << "The sample size should be at least zero. Value provided = "
42*14675a02SAndroid Build Coastguard Worker << sampled;
43*14675a02SAndroid Build Coastguard Worker }
44*14675a02SAndroid Build Coastguard Worker if (marked > total) {
45*14675a02SAndroid Build Coastguard Worker return FCP_STATUS(FAILED_PRECONDITION)
46*14675a02SAndroid Build Coastguard Worker << "The marked population " << marked
47*14675a02SAndroid Build Coastguard Worker << " should not exceed the total population " << total;
48*14675a02SAndroid Build Coastguard Worker }
49*14675a02SAndroid Build Coastguard Worker if (sampled > total) {
50*14675a02SAndroid Build Coastguard Worker return FCP_STATUS(FAILED_PRECONDITION)
51*14675a02SAndroid Build Coastguard Worker << "The sample size " << sampled
52*14675a02SAndroid Build Coastguard Worker << " should not exceed the total population " << total;
53*14675a02SAndroid Build Coastguard Worker }
54*14675a02SAndroid Build Coastguard Worker return std::unique_ptr<HypergeometricDistribution>(
55*14675a02SAndroid Build Coastguard Worker new HypergeometricDistribution(total, marked, sampled));
56*14675a02SAndroid Build Coastguard Worker }
57*14675a02SAndroid Build Coastguard Worker
PMF(double x)58*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::PMF(double x) { return PMFImpl(x, marked_); }
59*14675a02SAndroid Build Coastguard Worker
PMFImpl(double x,int counted)60*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::PMFImpl(double x, int counted) {
61*14675a02SAndroid Build Coastguard Worker if (x < 0 || x > sampled_ || x > counted) return 0;
62*14675a02SAndroid Build Coastguard Worker if (total_ + x < counted + sampled_) return 0;
63*14675a02SAndroid Build Coastguard Worker double lpmf = std::lgamma(sampled_ + 1) + std::lgamma(counted + 1) +
64*14675a02SAndroid Build Coastguard Worker std::lgamma(total_ - counted + 1) +
65*14675a02SAndroid Build Coastguard Worker std::lgamma(total_ - sampled_ + 1) - std::lgamma(x + 1) -
66*14675a02SAndroid Build Coastguard Worker std::lgamma(sampled_ - x + 1) - std::lgamma(counted - x + 1) -
67*14675a02SAndroid Build Coastguard Worker std::lgamma(total_ + 1) -
68*14675a02SAndroid Build Coastguard Worker std::lgamma(total_ - sampled_ - counted + x + 1);
69*14675a02SAndroid Build Coastguard Worker return std::exp(lpmf);
70*14675a02SAndroid Build Coastguard Worker }
71*14675a02SAndroid Build Coastguard Worker
CDF(double x)72*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::CDF(double x) {
73*14675a02SAndroid Build Coastguard Worker x = std::floor(x);
74*14675a02SAndroid Build Coastguard Worker double mean = marked_ * static_cast<double>(sampled_) / total_;
75*14675a02SAndroid Build Coastguard Worker if (x > mean) {
76*14675a02SAndroid Build Coastguard Worker return 1 - CDFImpl(sampled_ - x - 1, total_ - marked_);
77*14675a02SAndroid Build Coastguard Worker } else {
78*14675a02SAndroid Build Coastguard Worker return CDFImpl(x, marked_);
79*14675a02SAndroid Build Coastguard Worker }
80*14675a02SAndroid Build Coastguard Worker }
81*14675a02SAndroid Build Coastguard Worker
CDFImpl(double x,int counted)82*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::CDFImpl(double x, int counted) {
83*14675a02SAndroid Build Coastguard Worker double current_pmf = PMFImpl(x, counted);
84*14675a02SAndroid Build Coastguard Worker double result = 0;
85*14675a02SAndroid Build Coastguard Worker while (current_pmf > result * 1e-16) {
86*14675a02SAndroid Build Coastguard Worker result += current_pmf;
87*14675a02SAndroid Build Coastguard Worker current_pmf *= x;
88*14675a02SAndroid Build Coastguard Worker current_pmf *= total_ - counted - sampled_ + x;
89*14675a02SAndroid Build Coastguard Worker current_pmf /= counted - x + 1;
90*14675a02SAndroid Build Coastguard Worker current_pmf /= sampled_ - x + 1;
91*14675a02SAndroid Build Coastguard Worker --x;
92*14675a02SAndroid Build Coastguard Worker }
93*14675a02SAndroid Build Coastguard Worker return result;
94*14675a02SAndroid Build Coastguard Worker }
95*14675a02SAndroid Build Coastguard Worker
FindQuantile(double quantile,bool complement)96*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::FindQuantile(double quantile,
97*14675a02SAndroid Build Coastguard Worker bool complement) {
98*14675a02SAndroid Build Coastguard Worker if (quantile > 0.5) {
99*14675a02SAndroid Build Coastguard Worker quantile = 1 - quantile;
100*14675a02SAndroid Build Coastguard Worker complement = !complement;
101*14675a02SAndroid Build Coastguard Worker }
102*14675a02SAndroid Build Coastguard Worker if (complement) {
103*14675a02SAndroid Build Coastguard Worker return sampled_ - FindQuantileImpl(quantile, total_ - marked_) - 1;
104*14675a02SAndroid Build Coastguard Worker } else {
105*14675a02SAndroid Build Coastguard Worker return FindQuantileImpl(quantile, marked_);
106*14675a02SAndroid Build Coastguard Worker }
107*14675a02SAndroid Build Coastguard Worker }
108*14675a02SAndroid Build Coastguard Worker
FindQuantileImpl(double quantile,int counted)109*14675a02SAndroid Build Coastguard Worker double HypergeometricDistribution::FindQuantileImpl(double quantile,
110*14675a02SAndroid Build Coastguard Worker int counted) {
111*14675a02SAndroid Build Coastguard Worker double basic_bound = counted + sampled_ - total_ - 1;
112*14675a02SAndroid Build Coastguard Worker // An inverted tail bound gives a lower bound on the result
113*14675a02SAndroid Build Coastguard Worker double fancy_bound =
114*14675a02SAndroid Build Coastguard Worker sampled_ * (static_cast<double>(counted) / total_ -
115*14675a02SAndroid Build Coastguard Worker std::sqrt(-std::log(quantile) / (2 * sampled_)));
116*14675a02SAndroid Build Coastguard Worker double result = -1;
117*14675a02SAndroid Build Coastguard Worker if (basic_bound > result) result = basic_bound;
118*14675a02SAndroid Build Coastguard Worker if (fancy_bound > result) result = fancy_bound;
119*14675a02SAndroid Build Coastguard Worker result = std::floor(result);
120*14675a02SAndroid Build Coastguard Worker
121*14675a02SAndroid Build Coastguard Worker double current_cdf = CDFImpl(result, counted);
122*14675a02SAndroid Build Coastguard Worker double current_pmf = PMFImpl(result, counted);
123*14675a02SAndroid Build Coastguard Worker while (current_cdf < quantile && result < sampled_) {
124*14675a02SAndroid Build Coastguard Worker if (current_pmf > 0) {
125*14675a02SAndroid Build Coastguard Worker current_pmf /= result + 1;
126*14675a02SAndroid Build Coastguard Worker current_pmf /= total_ - counted - sampled_ + result + 1;
127*14675a02SAndroid Build Coastguard Worker current_pmf *= counted - result;
128*14675a02SAndroid Build Coastguard Worker current_pmf *= sampled_ - result;
129*14675a02SAndroid Build Coastguard Worker } else {
130*14675a02SAndroid Build Coastguard Worker current_pmf = PMFImpl(result + 1, counted);
131*14675a02SAndroid Build Coastguard Worker }
132*14675a02SAndroid Build Coastguard Worker current_cdf += current_pmf;
133*14675a02SAndroid Build Coastguard Worker ++result;
134*14675a02SAndroid Build Coastguard Worker }
135*14675a02SAndroid Build Coastguard Worker --result;
136*14675a02SAndroid Build Coastguard Worker return result;
137*14675a02SAndroid Build Coastguard Worker }
138*14675a02SAndroid Build Coastguard Worker
139*14675a02SAndroid Build Coastguard Worker } // namespace secagg
140*14675a02SAndroid Build Coastguard Worker } // namespace fcp
141