1 /*
2 * Copyright 2023 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "fcp/secagg/server/distribution_utilities.h"

#include <cmath>
#include <memory>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "fcp/base/monitoring.h"
#include "fcp/testing/testing.h"
25
26 namespace fcp {
27 namespace secagg {
28 namespace {
29
// One precomputed data point shared by the PMF and CDF parameterized tests.
// Fields are listed in aggregate-initialization order; the test tables below
// rely on that order.
struct HypergeometricCDForPMFInstance {
  // Point at which PMF() / CDF() is evaluated.
  const double x;
  // First argument forwarded to HypergeometricDistribution::Create().
  const int total;
  // Second argument forwarded to HypergeometricDistribution::Create().
  const int marked;
  // Third argument forwarded to HypergeometricDistribution::Create().
  const int sampled;
  // Reference value the computed PMF/CDF result is compared against.
  const double probability;
};
37
// One precomputed data point for the FindQuantile() parameterized test.
// Fields are listed in aggregate-initialization order; the test table below
// relies on that order.
struct HypergeometricQuantileInstance {
  // Probability passed to FindQuantile().
  const double probability;
  // First argument forwarded to HypergeometricDistribution::Create().
  const int total;
  // Second argument forwarded to HypergeometricDistribution::Create().
  const int marked;
  // Third argument forwarded to HypergeometricDistribution::Create().
  const int sampled;
  // FindQuantile(probability) must land in [lower, lower + 1].
  const int lower;
  // FindQuantile(probability, true) must land in [upper - 1, upper].
  const int upper;
};
46
47 class HypergeometricPMF
48 : public ::testing::TestWithParam<HypergeometricCDForPMFInstance> {};
49
50 class HypergeometricCDF
51 : public ::testing::TestWithParam<HypergeometricCDForPMFInstance> {};
52
53 class HypergeometricQuantile
54 : public ::testing::TestWithParam<HypergeometricQuantileInstance> {};
55
// Create(total, marked, sampled) must return a non-OK status for every
// argument combination below. Each result is bound to a descriptive local so
// a failing assertion names the violated case.
TEST(HypergeometricDistributionCreate, RejectsInvalidInputs) {
  const auto marked_exceeds_total =
      HypergeometricDistribution::Create(10, 11, 5);
  ASSERT_FALSE(marked_exceeds_total.ok());

  const auto sampled_exceeds_total =
      HypergeometricDistribution::Create(10, 5, 11);
  ASSERT_FALSE(sampled_exceeds_total.ok());

  const auto negative_marked = HypergeometricDistribution::Create(10, -1, 5);
  ASSERT_FALSE(negative_marked.ok());

  const auto negative_sampled = HypergeometricDistribution::Create(10, 5, -1);
  ASSERT_FALSE(negative_sampled.ok());

  const auto negative_total = HypergeometricDistribution::Create(-10, 5, 5);
  ASSERT_FALSE(negative_total.ok());

  const auto all_negative = HypergeometricDistribution::Create(-10, -5, -5);
  ASSERT_FALSE(all_negative.ok());
}
64
// Checks that PMF(x) reproduces the precomputed reference probability of each
// test instance to within a relative error of 1e-9.
TEST_P(HypergeometricPMF, ReturnsPrecomputedValues) {
  const HypergeometricCDForPMFInstance& test_params = GetParam();
  FCP_LOG(INFO) << "Testing hypergeometric pmf with x = " << test_params.x
                << " total = " << test_params.total
                << " marked = " << test_params.marked
                << " sampled = " << test_params.sampled << ".";
  auto p = HypergeometricDistribution::Create(
      test_params.total, test_params.marked, test_params.sampled);
  ASSERT_THAT(p, IsOk());
  const double result = p.value()->PMF(test_params.x);
  // std::abs from <cmath> guarantees the floating-point overload. An
  // unqualified abs() may resolve to the C `int abs(int)`, which truncates the
  // difference to 0 and makes the check below pass vacuously.
  // The 1e-30 term keeps the quotient finite when the expected value is 0.
  const double relative_error = std::abs(result - test_params.probability) /
                                (test_params.probability + 1e-30);
  EXPECT_LT(relative_error, 1e-9);
  FCP_LOG(INFO) << "result = " << result
                << " expected_result = " << test_params.probability
                << " relative_error = " << relative_error;
}
82
83 INSTANTIATE_TEST_SUITE_P(HypergeometricPMFTests, HypergeometricPMF,
84 ::testing::ValuesIn<HypergeometricCDForPMFInstance>(
85 {{-5, 9, 3, 3, 0.0},
86 {17, 9, 3, 3, 0.0},
87 {0, 10, 0, 5, 1.0},
88 {3, 10, 10, 5, 0.0},
89 {4, 15, 6, 12, 0.2967032967032967},
90 {38, 98, 63, 17, 0.0},
91 {2, 187, 105, 43, 5.423847289689941e-16},
92 {40, 980, 392, 103, 0.08225792329713294},
93 {89, 1489, 312, 370, 0.014089199026838601},
94 {100000, 1000000, 200000, 500000,
95 0.0019947087839501726}}));
96
// Checks that CDF(x) reproduces the precomputed reference probability of each
// test instance to within a relative error of 1e-9.
TEST_P(HypergeometricCDF, ReturnsPrecomputedValues) {
  const HypergeometricCDForPMFInstance& test_params = GetParam();
  FCP_LOG(INFO) << "Testing hypergeometric cdf with x = " << test_params.x
                << " total = " << test_params.total
                << " marked = " << test_params.marked
                << " sampled = " << test_params.sampled << ".";
  auto p = HypergeometricDistribution::Create(
      test_params.total, test_params.marked, test_params.sampled);
  ASSERT_THAT(p, IsOk());
  const double result = p.value()->CDF(test_params.x);
  // std::abs from <cmath> guarantees the floating-point overload. An
  // unqualified abs() may resolve to the C `int abs(int)`, which truncates the
  // difference to 0 and makes the check below pass vacuously.
  // The 1e-30 term keeps the quotient finite when the expected value is 0.
  const double relative_error = std::abs(result - test_params.probability) /
                                (test_params.probability + 1e-30);
  EXPECT_LT(relative_error, 1e-9);
  FCP_LOG(INFO) << "result = " << result
                << " expected_result = " << test_params.probability
                << " relative_error = " << relative_error;
}
114
115 INSTANTIATE_TEST_SUITE_P(HypergeometricCDFTests, HypergeometricCDF,
116 ::testing::ValuesIn<HypergeometricCDForPMFInstance>(
117 {{-5, 9, 3, 3, 0.0},
118 {17, 9, 3, 3, 1.0},
119 {0, 10, 0, 5, 1.0},
120 {3, 10, 10, 5, 0.0},
121 {4.5, 15, 6, 12, 0.34065934065934067},
122 {38, 98, 63, 17, 1.0},
123 {2, 187, 105, 43, 5.526570670097338e-16},
124 {40, 980, 392, 103, 0.4430562850817352},
125 {89, 1489, 312, 370, 0.9599670222722507},
126 {100000, 1000000, 200000, 500000,
127 0.5009973543919738}}));
128
// Checks both FindQuantile() variants against precomputed bounds. Each result
// is allowed to differ from its reference bound by at most one.
TEST_P(HypergeometricQuantile, ReturnsPrecomputedValues) {
  const HypergeometricQuantileInstance& instance = GetParam();
  FCP_LOG(INFO) << "Testing hypergeometric quantile with probability = "
                << instance.probability << " total = " << instance.total
                << " marked = " << instance.marked
                << " sampled = " << instance.sampled << ".";
  auto distribution = HypergeometricDistribution::Create(
      instance.total, instance.marked, instance.sampled);
  ASSERT_THAT(distribution, IsOk());

  // Single-argument call: the result must fall in [lower, lower + 1].
  const int lower_range_end = instance.lower + 1;
  const double lower_quantile =
      distribution.value()->FindQuantile(instance.probability);
  EXPECT_GE(lower_quantile, instance.lower);
  EXPECT_LE(lower_quantile, lower_range_end);
  FCP_LOG(INFO) << "Lower result = " << lower_quantile
                << " which should be between " << instance.lower << " and "
                << lower_range_end << ".";

  // Call with the second argument set to true: the result must fall in
  // [upper - 1, upper].
  // NOTE(review): the flag presumably selects the upper-tail quantile;
  // confirm against distribution_utilities.h.
  const int upper_range_start = instance.upper - 1;
  const double upper_quantile =
      distribution.value()->FindQuantile(instance.probability, true);
  EXPECT_LE(upper_quantile, instance.upper);
  EXPECT_GE(upper_quantile, upper_range_start);
  FCP_LOG(INFO) << "Upper result = " << upper_quantile
                << " which should be between " << upper_range_start
                << " and " << instance.upper << ".";
}
151
152 INSTANTIATE_TEST_SUITE_P(HypergeometricQuantileTests, HypergeometricQuantile,
153 ::testing::ValuesIn<HypergeometricQuantileInstance>(
154 {{0.5, 10, 0, 5, -1, 0},
155 {0.2, 10, 10, 5, 4, 5},
156 {0.97, 15, 6, 12, 5, 3},
157 {0.0001, 98, 63, 17, 3, 17},
158 {1e-05, 187, 105, 43, 11, 36},
159 {3e-08, 980, 392, 103, 16, 67},
160 {1.1e-09, 1489, 312, 370, 38, 119},
161 {1e-18, 1000000, 200000, 500000, 98248,
162 101751}}));
163
164 } // namespace
165 } // namespace secagg
166 } // namespace fcp
167