xref: /aosp_15_r20/external/rappor/analysis/R/association.R (revision 2abb31345f6c95944768b5222a9a5ed3fc68cc00)
1*2abb3134SXin Li# Copyright 2014 Google Inc. All rights reserved.
2*2abb3134SXin Li#
3*2abb3134SXin Li# Licensed under the Apache License, Version 2.0 (the "License");
4*2abb3134SXin Li# you may not use this file except in compliance with the License.
5*2abb3134SXin Li# You may obtain a copy of the License at
6*2abb3134SXin Li#
7*2abb3134SXin Li#     http://www.apache.org/licenses/LICENSE-2.0
8*2abb3134SXin Li#
9*2abb3134SXin Li# Unless required by applicable law or agreed to in writing, software
10*2abb3134SXin Li# distributed under the License is distributed on an "AS IS" BASIS,
11*2abb3134SXin Li# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*2abb3134SXin Li# See the License for the specific language governing permissions and
13*2abb3134SXin Li# limitations under the License.
14*2abb3134SXin Li
15*2abb3134SXin Lilibrary(parallel)  # mclapply
16*2abb3134SXin Li
source.rappor <- function(rel_path)  {
  # Sources an R file addressed relative to the RAPPOR checkout.
  #
  # The value of the RAPPOR_REPO environment variable (empty string when
  # unset) is prepended verbatim to rel_path, so it must end with a path
  # separator when set; with it unset, rel_path is resolved against the
  # current working directory.
  repo_root <- Sys.getenv("RAPPOR_REPO", "")
  source(paste0(repo_root, rel_path))
}
21*2abb3134SXin Li
22*2abb3134SXin Lisource.rappor("analysis/R/util.R")  # for Log
23*2abb3134SXin Lisource.rappor("analysis/R/decode.R")  # for ComputeCounts
24*2abb3134SXin Li
25*2abb3134SXin Li#
26*2abb3134SXin Li# Tools used to estimate variable distributions of up to three variables
27*2abb3134SXin Li#     in RAPPOR. This contains the functions relevant to estimating joint
28*2abb3134SXin Li#     distributions.
29*2abb3134SXin Li
GetOtherProbs <- function(counts, map_by_cohort, marginal, params, pstar,
                          qstar) {
  # Computes the marginal for the "other" category.
  #
  # The counts that the decoded (known) strings are expected to produce are
  # subtracted from the observed counts; whatever remains is attributed to
  # strings that were not decoded ("other").
  #
  # Args:
  #   counts: m x (k+1) matrix with counts of each bit for each
  #       cohort (m = # cohorts total, k = # bits in Bloom filter); the first
  #       column stores the total counts.
  #   map_by_cohort: list of m matrices, one per cohort, with one row per bit
  #       and one column per candidate string, marking the bits each string
  #       sets in that cohort.
  #   marginal: object containing the estimated proportions ($proportion) of
  #       the known strings and the strings themselves ($string), e.g. the
  #       $fit data frame returned by Decode.
  #   params: RAPPOR encoding parameters (only $m and $k are read).
  #   pstar: probability that a reported bit is 1 when the true bit is 0.
  #   qstar: probability that a reported bit is 1 when the true bit is 1.
  #
  # Returns:
  #   List of vectors of probabilities that each bit was set by the "other"
  #   category.  The list is indexed by cohort and each vector has length k.
  #   Values are capped at 1; values that are undefined because a cohort has
  #   no remaining "other" reports (0/0 or x/0) are set to 0.
  m <- params$m
  k <- params$k

  # length(marginal$string) works for both lists and data frames; the length
  # of a data frame is its number of columns, not rows.
  if (length(marginal$string) == 0) {
    # No strings were found, so all nonzero counts were set by "other".
    reduced_counts <- counts
  } else {
    N <- sum(counts[, 1])

    # Counts of known strings to remove from each cohort.
    known_counts <- ceiling(marginal$proportion * N / m)
    sum_known <- sum(known_counts)

    # Counts set by known strings without noise considerations, one row per
    # cohort.  NOTE: drop = FALSE is necessary if there is one candidate.
    # Rows are filled with byrow = TRUE so row j really is cohort j (sapply
    # would give a k x m matrix, and reshaping that with dim() scrambles
    # cohorts and bits); matrix() also keeps the m x k shape when k == 1.
    true_counts <- matrix(
      unlist(lapply(map_by_cohort, function(map_for_cohort) {
        candidate_map <- as.matrix(
          map_for_cohort[, marginal$string, drop = FALSE])
        as.vector(candidate_map %*% known_counts)
      })),
      nrow = m, ncol = k, byrow = TRUE)

    # Counts set by known vals: zero bits adjusted by pstar plus true bits
    # adjusted by qstar.
    expected_known <- (sum_known - true_counts) * pstar + true_counts * qstar

    # Add the left hand sums to make it an m x (k+1) "counts" matrix.
    expected_known <- cbind(sum_known, expected_known)

    # Counts set by the "other" category.
    reduced_counts <- counts - expected_known
    reduced_counts[reduced_counts < 0] <- 0
  }

  probs_other <- apply(reduced_counts, 1, function(cohort_row) {
    cohort_row[-1] / cohort_row[1]
  })

  # Protect against R's matrix/vector confusion (apply returns a plain
  # vector when k == 1).
  dim(probs_other) <- c(k, m)

  # Undefined ratios must be zeroed BEFORE capping at 1; otherwise Inf would
  # be turned into 1 and the is.infinite() guard could never fire.
  probs_other[is.nan(probs_other)] <- 0
  probs_other[is.infinite(probs_other)] <- 0
  probs_other[probs_other > 1] <- 1

  # Convert it from a k x m matrix to a list indexed by m cohorts.
  # as.data.frame makes each cohort a column, which can be indexed by
  # result[[cohort]].
  as.list(as.data.frame(probs_other))
}
106*2abb3134SXin Li
GetCondProbBooleanReports <- function(reports, pstar, qstar, num_cores) {
  # Computes the conditional probability of each Boolean report given each
  # possible true value.
  #
  # Args:
  #   reports: list of RAPPOR reports; only the first bit of each report is
  #       inspected (Boolean reports have a single bit).
  #   pstar: P(report = 1 | value = FALSE).
  #   qstar: P(report = 1 | value = TRUE).
  #   num_cores: number of cores passed to mclapply.
  #
  # Returns:
  #   List with one length-2 vector per report: the probability of that
  #   report given value TRUE, followed by the probability given FALSE.
  likelihood_when_set <- c(qstar, pstar)
  likelihood_when_unset <- c(1 - qstar, 1 - pstar)

  mclapply(reports, function(report) {
    bit_is_set <- report[[1]] == 1
    if (bit_is_set) likelihood_when_set else likelihood_when_unset
  }, mc.cores = num_cores)
}
134*2abb3134SXin Li
GetCondProbStringReports <- function(reports, cohorts, map, m, pstar, qstar,
                                     marginal, prob_other = NULL, num_cores) {
  # Wrapper around GetCondProb. Given a set of reports, cohorts, map and
  # parameters m, p*, and q*, it first computes bit indices by cohort, and
  # then applies GetCondProb individually to each report.
  #
  # Args:
  #   reports: RAPPOR reports as a list of bit arrays
  #   cohorts: cohort of each report (vector or list, same order as reports)
  #   map: map object; map$map_by_cohort[[cohort]] has one column per
  #       candidate string marking the bits it sets in that cohort
  #   m: number of cohorts
  #   pstar, qstar: standard params computed from rappor parameters
  #   marginal: list containing marginal estimates (output of Decode)
  #   prob_other: NULL, or a list indexed by cohort whose entries are vectors
  #     of length k indicating how often each bit in the Bloom filter was set
  #     by a string in the "other" category.
  #   num_cores: number of cores passed to mclapply
  #
  # Returns:
  #   List (one entry per report) with the conditional probability of the
  #   report given each of the strings in marginal$string (plus "other" when
  #   prob_other is supplied).

  # Get bit indices that are set per candidate per cohort:
  # bit_indices_by_cohort[[cohort]][[s]] lists the bits string s sets.
  bit_indices_by_cohort <- lapply(seq_len(m), function(cohort) {
    map_for_cohort <- map$map_by_cohort[[cohort]]
    lapply(marginal$string, function(x) {
      which(map_for_cohort[, x])
    })
  })

  # seq_along() yields no iterations for an empty report list (seq(0) would
  # iterate over c(1, 0)); cohorts[[i]] works for both vectors and lists.
  cond_report_dist <- mclapply(seq_along(reports), function(i) {
    cohort <- cohorts[[i]]
    bit_indices <- bit_indices_by_cohort[[cohort]]
    # prob_other[[cohort]] is NULL when prob_other is NULL or a list of NULLs.
    GetCondProb(reports[[i]], pstar, qstar, bit_indices,
                prob_other = prob_other[[cohort]])
  }, mc.cores = num_cores)
  cond_report_dist
}
174*2abb3134SXin Li
175*2abb3134SXin Li
GetCondProb <- function(report, pstar, qstar, bit_indices, prob_other = NULL) {
  # Given the observed bit array, estimate P(report | true value).
  # Probabilities are estimated for all truth values.
  #
  # Args:
  #   report: A single observed RAPPOR report (binary vector of length k).
  #   pstar: probability that a reported bit is 1 when the true bit is 0.
  #   qstar: probability that a reported bit is 1 when the true bit is 1.
  #   bit_indices: list with one entry for each candidate.  Each entry is an
  #     integer vector specifying which bits are set for the candidate in
  #     the report's cohort.
  #   prob_other: vector of length k, indicating how often each bit in the
  #     Bloom filter was set by a string in the "other" category.
  #
  # Returns:
  #   Numeric vector with the conditional probability of the report given
  #   each candidate (in the order of bit_indices), followed by the
  #   probability under the "other" category when prob_other is not NULL.
  #   Bits are treated as independent, so each probability is a product of
  #   per-bit terms.

  # Per-bit probability of the observed value for bits a candidate does NOT
  # set.
  probs <- ifelse(report == 1, pstar, 1 - pstar)

  # Find the likelihood of report given each candidate string.  A logical
  # mask is used instead of probs[-x]: negative indexing with an empty x
  # selects nothing instead of every bit.
  prob_obs_vals <- vapply(bit_indices, function(x) {
    not_set <- !(seq_along(report) %in% x)
    prod(probs[not_set], ifelse(report[x] == 1, qstar, 1 - qstar))
  }, numeric(1))

  if (is.null(prob_other)) {
    return(prob_obs_vals)
  }

  # Account for the "other" category.
  prob_other_val <- prod(prob_other[which(report == 1)],
                         (1 - prob_other)[which(report == 0)])
  c(prob_obs_vals, prob_other_val)
}
210*2abb3134SXin Li
UpdatePij <- function(pij, cond_prob) {
  # Performs one EM update of the probability estimate.
  #
  # Args:
  #   pij: current estimate of the distribution (vector or array shaped like
  #       each element of cond_prob)
  #   cond_prob: list with one entry per report giving the conditional
  #       probability of that report for every cell of pij
  #
  # Returns:
  #   The average over reports of the normalized product cond_prob * pij,
  #   i.e. the updated pij's.  Entries that evaluate to NaN (e.g. when the
  #   product sums to zero for a report) contribute 0.
  #
  # NOTE: Not using mclapply here because we have a faster C++ implementation.
  # mclapply spawns multiple processes, and each process can take up 3 GB+ or 5
  # GB+ of memory.
  posteriors <- lapply(cond_prob, function(likelihood) {
    unnormalized <- likelihood * pij
    posterior <- unnormalized / sum(unnormalized)
    replace(posterior, is.nan(posterior), 0)
  })
  total <- Reduce(`+`, posteriors)
  total / length(posteriors)
}
232*2abb3134SXin Li
ComputeVar <- function(cond_prob, est) {
  # Computes the variance of the estimated pij's.
  #
  # Args:
  #   cond_prob: list with one entry per report giving the conditional
  #       probability of that report for every cell of the distribution
  #   est: estimated pij's (same shape as each element of cond_prob)
  #
  # Returns:
  #   List with var_cov (inverse of the information matrix), sd (square root
  #   of its diagonal, shaped like cond_prob[[1]]) and inform (the
  #   information matrix; cells are in column-major order).
  inform <- Reduce("+", lapply(cond_prob, function(x) {
    x_vec <- as.vector(x)
    outer(x_vec, x_vec) / (sum(x * est))^2
  }))
  var_cov <- solve(inform)

  # array() keeps the full shape of the joint distribution for any number of
  # variables; matrix(x, dims) would only use dims[1] as the row count.
  shape <- dim(cond_prob[[1]])
  if (is.null(shape)) {
    shape <- length(cond_prob[[1]])
  }
  sd <- array(sqrt(diag(var_cov)), shape)
  list(var_cov = var_cov, sd = sd, inform = inform)
}
250*2abb3134SXin Li
EM <- function(cond_prob, starting_pij = NULL, estimate_var = FALSE,
               max_em_iters = 1000, epsilon = 10^-6, verbose = FALSE) {
  # Performs estimation with the EM algorithm.
  #
  # Args:
  #   cond_prob: list with one entry per report giving the conditional
  #       probability of that report for every cell of the distribution
  #   starting_pij: initial pij's; defaults to the uniform distribution over
  #       the cells of cond_prob[[1]]
  #   estimate_var: flags whether we should estimate the variance
  #       of our computed distribution
  #   max_em_iters: maximum number of EM iterations
  #   epsilon: convergence threshold on the largest absolute change of any
  #       pij between consecutive iterations
  #   verbose: NOTE(review): currently unused; every iteration that does not
  #       converge is logged regardless of this flag.
  #
  # Returns:
  #   List with est (final pij's), sd, var_cov and inform (all NULL unless
  #   estimate_var is TRUE), hist (every iterate, starting with the initial
  #   pij's) and num_em_iters (number of EM updates performed).

  state_space <- dim(cond_prob[[1]])
  if (is.null(starting_pij)) {
    starting_pij <- array(1 / prod(state_space), state_space)
  }
  pij <- list(starting_pij)

  # Tracked explicitly so the count is 0 when no update is performed.
  num_em_iters <- 0
  # length() also works when starting_pij has no dim attribute (nrow() would
  # return NULL there).
  if (length(starting_pij) > 0) {
    # seq_len() runs zero iterations when max_em_iters is 0.
    for (i in seq_len(max_em_iters)) {
      pij[[i + 1]] <- UpdatePij(pij[[i]], cond_prob)
      num_em_iters <- i
      dif <- max(abs(pij[[i + 1]] - pij[[i]]))
      if (dif < epsilon) {
        break
      }
      Log('EM iteration %d, dif = %e', i, dif)
    }
  }

  # Compute the variance of the estimate.
  est <- pij[[length(pij)]]
  var_cov <- NULL
  inform <- NULL
  sd <- NULL
  if (estimate_var) {
    var_info <- ComputeVar(cond_prob, est)
    sd <- var_info$sd
    inform <- var_info$inform
    var_cov <- var_info$var_cov
  }
  # inform is returned so callers can pass it on, e.g. to TestIndependence.
  list(est = est, sd = sd, var_cov = var_cov, inform = inform, hist = pij,
       num_em_iters = num_em_iters)
}
301*2abb3134SXin Li
TestIndependence <- function(est, inform) {
  # Tests the degree of independence between two variables.
  #
  # Args:
  #   est: estimated pij values as a 2-D matrix (rows index the first
  #       variable, columns the second)
  #   inform: information matrix of the estimate, with cells in column-major
  #       order (as computed by ComputeVar)
  #
  # Returns:
  #   List with stat (chi-squared statistic comparing est with the outer
  #   product of its marginals, as a 1 x 1 matrix) and pval (its p-value).
  expec <- outer(rowSums(est), colSums(est))
  diffs <- matrix(est - expec, ncol = 1)
  stat <- t(diffs) %*% inform %*% diffs
  df <- (nrow(est) - 1) * (ncol(est) - 1)
  # lower.tail is spelled out; 'lower' only worked via partial matching.
  list(stat = stat, pval = pchisq(stat, df, lower.tail = FALSE))
}
318*2abb3134SXin Li
UpdateJointConditional <- function(cond_report_dist, joint_conditional = NULL) {
  # Folds one more variable into the per-report joint conditional
  # distribution.  Variables are conditionally independent of one another,
  # so the joint conditional is the outer product of the per-variable
  # conditionals.
  #
  # Args:
  #   cond_report_dist: list with one entry per report holding the
  #       conditional distribution of the new variable.
  #   joint_conditional: list with one entry per report holding the joint
  #       conditional of the variables folded in so far, or NULL when
  #       cond_report_dist belongs to the first variable.
  #
  # Returns:
  #   List with one array per report.  Each array gains one trailing
  #   dimension, indexed by the values of the new variable.
  if (is.null(joint_conditional)) {
    # First variable: wrap each conditional as a 1-D array.
    return(lapply(cond_report_dist, array))
  }
  Map(outer, joint_conditional, cond_report_dist)
}
344*2abb3134SXin Li
ComputeDistributionEM <- function(reports, report_cohorts, maps,
                                  ignore_other = FALSE,
                                  params = NULL,
                                  params_list = NULL,
                                  marginals = NULL,
                                  estimate_var = FALSE,
                                  num_cores = 10,
                                  em_iter_func = EM,
                                  max_em_iters = 1000) {
  # Computes the joint distribution of num_variables variables, where
  #     num_variables is chosen by the client, using the EM algorithm.
  #
  # Args:
  #   reports: A list of num_variables elements, each holding that variable's
  #       reports (a list of bit arrays, one per report).
  #   report_cohorts: A num_variables-element list; the ith element is an array
  #       containing the cohort of the jth report for the ith variable.
  #   maps: A num_variables-element list containing the map for each variable
  #   ignore_other: If TRUE, no "other" category is estimated.
  #   params: RAPPOR encoding parameters.  If set, all variables are assumed to
  #       be encoded with these parameters.
  #   params_list: A list of num_variables elements, each of which is the
  #       RAPPOR encoding parameters for a variable (a list itself).  If set,
  #       it must be the same length as 'reports'.
  #   marginals: Optional list of estimated marginals for each variable; when
  #       NULL each marginal is decoded from the reports.
  #   estimate_var: A flag telling whether to estimate the variance.
  #   num_cores: Number of cores used to compute per-report conditionals.
  #   em_iter_func: Function that implements the iterative EM algorithm.
  #   max_em_iters: Maximum number of EM iterations.
  #
  # Returns:
  #   NULL if a decoded marginal is empty.  Otherwise a list with fit (the
  #   estimated joint distribution; its dimnames are the strings found for
  #   each variable, plus "Other" unless ignore_other is set or k == 1), sd,
  #   em_elapsed_time, num_em_iters and em (the raw EM result).

  num_variables <- length(reports)
  if (num_variables == 0) {
    stop("reports must contain at least one variable")
  }

  if (is.null(params) && is.null(params_list)) {
    stop("Either params or params_list must be passed")
  }
  if (is.null(params) && length(params_list) != num_variables) {
    stop("params_list must have the same length as reports")
  }

  Log('Computing joint conditional')

  # Compute the counts for each variable and then do conditionals.
  joint_conditional <- NULL
  found_strings <- list()

  for (j in seq_len(num_variables)) {
    Log('Processing var %d', j)

    var_report <- reports[[j]]
    var_cohort <- report_cohorts[[j]]
    var_map <- maps[[j]]
    if (!is.null(params)) {
      var_params <- params
    } else {
      var_params <- params_list[[j]]
    }

    var_counts <- NULL
    if (is.null(marginals)) {
      Log('\tSumming bits to gets observed counts')
      var_counts <- ComputeCounts(var_report, var_cohort, var_params)

      Log('\tDecoding marginal')
      marginal <- Decode(var_counts, var_map$all_cohorts_map, var_params,
                         quiet = TRUE)$fit
      Log('\tMarginal for var %d has %d values:', j, nrow(marginal))
      print(marginal[, c('estimate', 'proportion')])  # rownames are the string
      cat('\n')

      if (nrow(marginal) == 0) {
        Log('ERROR: Nothing decoded for variable %d', j)
        return(NULL)
      }
    } else {
      marginal <- marginals[[j]]
    }
    found_strings[[j]] <- marginal$string

    p <- var_params$p
    q <- var_params$q
    f <- var_params$f
    # pstar and qstar needed to compute other probabilities as well as for
    # inputs to GetCondProb{Boolean, String}Reports subsequently
    pstar <- (1 - f / 2) * p + (f / 2) * q
    qstar <- (1 - f / 2) * q + (f / 2) * p
    k <- var_params$k

    # Ignore other probability if either ignore_other is set or k == 1
    # (Boolean RAPPOR)
    if (ignore_other || (k == 1)) {
      # A list of NULLs: every cohort lookup yields no "other" probabilities.
      prob_other <- vector(mode = "list", length = var_params$m)
    } else {
      # Compute the probability of the "other" category
      if (is.null(var_counts)) {
        var_counts <- ComputeCounts(var_report, var_cohort, var_params)
      }
      prob_other <- GetOtherProbs(var_counts, var_map$map_by_cohort, marginal,
                                  var_params, pstar, qstar)
      found_strings[[j]] <- c(found_strings[[j]], "Other")
    }

    # Get the joint conditional distribution
    Log('\tGetCondProb for each report (%d cores)', num_cores)

    # TODO(pseudorandom): check RAPPOR type more systematically instead of by
    # checking if k == 1
    if (k == 1) {
      cond_report_dist <- GetCondProbBooleanReports(var_report, pstar, qstar,
                                                    num_cores)
    } else {
      cond_report_dist <- GetCondProbStringReports(var_report,
                                var_cohort, var_map, var_params$m, pstar, qstar,
                                marginal, prob_other, num_cores)
    }

    Log('\tUpdateJointConditional')

    # Update the joint conditional distribution of all variables
    joint_conditional <- UpdateJointConditional(cond_report_dist,
                                                joint_conditional)
  }

  N <- length(joint_conditional)
  dimensions <- dim(joint_conditional[[1]])
  # e.g. 2 x 3
  dimensions_str <- paste(dimensions, collapse = ' x ')
  total_entries <- prod(c(N, dimensions))

  Log('Starting EM with N = %d matrices of size %s (%d entries)',
      N, dimensions_str, total_entries)

  start_time <- proc.time()[['elapsed']]

  # Run expectation maximization to find joint distribution
  em <- em_iter_func(joint_conditional, max_em_iters = max_em_iters,
                     epsilon = 10 ^ -6, verbose = FALSE,
                     estimate_var = estimate_var)

  em_elapsed_time <- proc.time()[['elapsed']] - start_time

  dimnames(em$est) <- found_strings
  # Return results in a usable format
  list(fit = em$est,
       sd = em$sd,
       em_elapsed_time = em_elapsed_time,
       num_em_iters = em$num_em_iters,
       # This last field is implementation-specific; it can be used for
       # interactive debugging.
       em = em)
}
492