xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/distribution_utilities.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2023 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_SECAGG_SERVER_DISTRIBUTION_UTILITIES_H_
18 #define FCP_SECAGG_SERVER_DISTRIBUTION_UTILITIES_H_
19 
20 #include <memory>
21 
22 #include "fcp/base/monitoring.h"
23 
24 namespace fcp {
25 namespace secagg {
26 
27 // Represents a Hypergeometric distribution with parameters fixed at creation of
28 // the object. Allows to query certain distribution functions.
29 class HypergeometricDistribution {
30  public:
31   static StatusOr<std::unique_ptr<HypergeometricDistribution>> Create(
32       int total, int marked, int sampled);
33 
34   // Evaluates the probability mass funciton of the random variable at x.
35   double PMF(double x);
36 
37   // Evaluates the cumulative distribution function of the random variable at x.
38   double CDF(double x);
39 
40   // Finds the value whose cdf is quantile rounded outwards to an integer.
41   // Setting complement to true is equivalent to setting quantile = 1 - quantile
42   // but can avoid numerical error in the extreme upper tail.
43   double FindQuantile(double quantile, bool complement = false);
44 
45  private:
46   const int total_;
47   const int marked_;
48   const int sampled_;
49 
HypergeometricDistribution(int total,int marked,int sampled)50   HypergeometricDistribution(int total, int marked, int sampled)
51       : total_(total), marked_(marked), sampled_(sampled) {}
52 
53   double PMFImpl(double x, int counted);
54 
55   double CDFImpl(double x, int counted);
56 
57   double FindQuantileImpl(double quantile, int counted);
58 };
59 
60 }  // namespace secagg
61 }  // namespace fcp
62 #endif  // FCP_SECAGG_SERVER_DISTRIBUTION_UTILITIES_H_
63