xref: /aosp_15_r20/external/rappor/analysis/R/fast_em.R (revision 2abb31345f6c95944768b5222a9a5ed3fc68cc00)
1*2abb3134SXin Li# fast_em.R: Wrapper around analysis/cpp/fast_em.cc.
2*2abb3134SXin Li#
3*2abb3134SXin Li# This serializes the input, shells out, and deserializes the output.
4*2abb3134SXin Li
# Concatenates the contents of a list of matrices into one plain vector.
#
# Each matrix has its dim attribute dropped by as.vector() (leaving its
# column-major values); the pieces are then joined in list order by unlist().
# An empty list yields NULL.
.Flatten <- function(list_of_matrices) {
  unlist(lapply(list_of_matrices, function(m) as.vector(m)))
}
12*2abb3134SXin Li
# Serializes a non-empty list of equally-sized matrices to the binary
# connection `f` using this tagged layout:
#
#   "ne " <int num_entries> "es " <int entry_size> "dat" <doubles...>
#
# writeBin() writes each tag as a NUL-terminated string (3 chars + NUL).
# The "dat" block holds num_entries * entry_size doubles: each matrix in
# column-major order, matrices in list order.
#
# Args:
#   list_of_matrices: non-empty list of matrices, all with the same number
#     of cells as the first one.
#   f: a connection opened for binary writing.
# Stops (before writing anything) if the list is empty or ragged.
.WriteListOfMatrices <- function(list_of_matrices, f) {
  # NOTE: UpdateJointConditional does outer product of dimensions!

  num_entries <- length(list_of_matrices)
  if (num_entries == 0) {
    stop("Expected at least one matrix to write")
  }

  # For 2x3, this is 6
  entry_size <- as.integer(prod(dim(list_of_matrices[[1]])))

  # The payload is read back as num_entries blocks of entry_size values, so
  # a ragged list would silently misalign every entry after the first.
  sizes <- vapply(list_of_matrices, length, integer(1))
  if (any(sizes != entry_size)) {
    stop(sprintf("All matrices must have %d entries; got sizes %s",
                 entry_size, paste(unique(sizes), collapse = ", ")))
  }

  # writeBin() emits 4-byte ints for integer/logical vectors; force 8-byte
  # doubles so the "dat" block has a fixed element type.
  # NOTE(review): assumes the C++ reader consumes "dat" as doubles — confirm
  # against analysis/cpp/fast_em.cc.
  flattened <- as.double(.Flatten(list_of_matrices))

  writeBin('ne ', con = f)
  writeBin(num_entries, con = f)

  Log('Wrote num_entries = %d', num_entries)

  writeBin('es ', con = f)
  writeBin(entry_size, con = f)

  Log('Wrote entry_size = %d', entry_size)

  # now write the data
  writeBin('dat', con = f)
  writeBin(flattened, con = f)
}
37*2abb3134SXin Li
# Consumes one NUL-terminated string from the binary connection `f` and
# checks that it equals `tag`.
#
# Returns invisible NULL on a match. Stops with an error when nothing could
# be read (e.g. end of input) or when the string read differs from `tag`.
.ExpectTag <- function(f, tag) {
  got <- readBin(con = f, what = "character", n = 1)

  if (length(got) != 1) {
    stop(sprintf("Failed to read a tag '%s'", tag))
  }
  if (got == tag) {
    return(invisible(NULL))
  }
  stop(sprintf("Expected '%s', got '%s'", tag, got))
}
50*2abb3134SXin Li
# Parses the tagged result read from the binary connection `f`:
#
#   "emi" <int num_em_iters> "pij" <entry_size doubles>
#
# Args:
#   f: a connection opened for binary reading.
#   entry_size: number of doubles expected after the "pij" tag.
#   matrix_dims: dimensions assigned to the parsed values; their product
#     must equal entry_size.
# Returns:
#   list(est = <array with dims matrix_dims>, num_em_iters = <integer>).
#   Stops with an error on a wrong tag or a truncated input.
.ReadResult <- function(f, entry_size, matrix_dims) {
  .ExpectTag(f, "emi")
  # NOTE: assuming R integers are 4 bytes (uint32_t)
  num_em_iters <- readBin(con = f, what = "integer", n = 1)
  if (length(num_em_iters) != 1) {
    stop("Failed to read the number of EM iterations")
  }

  .ExpectTag(f, "pij")
  pij <- readBin(con = f, what = "double", n = entry_size)
  # readBin() silently returns fewer values on a short file; report that
  # directly rather than failing later with an opaque dim<- error.
  if (length(pij) != entry_size) {
    stop(sprintf("Expected %d pij entries, got %d",
                 as.integer(entry_size), length(pij)))
  }

  # Adjust dimensions
  dim(pij) <- matrix_dims

  Log("Number of EM iterations: %d", num_em_iters)
  Log("PIJ read from external implementation:")
  print(pij)

  # Only the estimate and the iteration count are produced here (no sd,
  # var_cov or hist).
  list(est = pij, num_em_iters = num_em_iters)
}
69*2abb3134SXin Li
# Logs diagnostic counts for a list of matrices before it is handed to the
# external EM implementation: cells equal to +Inf, NaN cells, and cells
# equal to zero, each summed over all matrices.
#
# Args:
#   joint_conditional: list of numeric matrices (an empty list gives zeros).
# Returns:
#   Invisibly, a named numeric vector c(inf =, nan =, zero =).
.SanityChecks <- function(joint_conditional) {
  # Comparisons such as `NaN == Inf` evaluate to NA, so na.rm = TRUE is
  # required; without it a single NaN turns the inf/zero totals into NA.
  count_cells <- function(predicate) {
    per_matrix <- vapply(joint_conditional, function(m) {
      sum(predicate(m), na.rm = TRUE)
    }, numeric(1))
    sum(per_matrix)
  }

  # Only positive infinity is counted, as before.
  total_inf <- count_cells(function(m) m == Inf)
  total_nan <- count_cells(is.nan)
  total_zero <- count_cells(function(m) m == 0.0)

  Log('total inf: %s', total_inf)
  Log('total nan: %s', total_nan)
  Log('total zero: %s', total_zero)

  invisible(c(inf = total_inf, nan = total_nan, zero = total_zero))
}
93*2abb3134SXin Li
# Builds an EM estimator function that delegates to an external executable.
#
# Args:
#   em_executable: the program to run. It is interpolated into the shell
#     command as-is (not quoted), as in the original implementation.
#   tmp_dir: directory for the exchange files 'list_of_matrices.bin' (input)
#     and 'pij.bin' (output); both names are fixed, so they are overwritten
#     on every call.
#
# The returned function takes:
#   joint_conditional: non-empty list of 2-D matrices. The first matrix's
#     dims define the shape of the result.
#   max_em_iters: passed to the executable as its third argument.
#   epsilon, verbose, estimate_var: accepted for signature compatibility but
#     not used by this implementation.
# It returns the list parsed by .ReadResult() from the output file, and stops
# on invalid input, a non-zero exit status, or malformed output.
ConstructFastEM <- function(em_executable, tmp_dir) {
  function(joint_conditional, max_em_iters = 1000,
           epsilon = 10 ^ -6, verbose = FALSE,
           estimate_var = FALSE) {
    if (length(joint_conditional) == 0) {
      stop("Expected at least one matrix in joint_conditional")
    }

    matrix_dims <- dim(joint_conditional[[1]])
    # Check that number of dimensions is 2.
    if (length(matrix_dims) != 2) {
      Log('FATAL: Expected 2 dimensions, got %d', length(matrix_dims))
      stop(sprintf("Expected 2 dimensions, got %d", length(matrix_dims)))
    }

    entry_size <- prod(matrix_dims)
    Log('entry size: %d', entry_size)

    .SanityChecks(joint_conditional)

    input_path <- file.path(tmp_dir, 'list_of_matrices.bin')
    Log("Writing flattened list of matrices to %s", input_path)
    f <- file(input_path, 'wb')  # binary file
    # Always release the connection, even if serialization fails; closing it
    # also flushes the data before the executable reads the file.
    tryCatch(.WriteListOfMatrices(joint_conditional, f), finally = close(f))
    Log("Done writing %s", input_path)

    output_path <- file.path(tmp_dir, 'pij.bin')

    # Quote the paths so a tmp_dir with spaces or shell metacharacters
    # survives intact. Format the iteration count as a plain integer:
    # sprintf("%s", 1e5) would yield "1e+05".
    cmd <- sprintf("%s %s %s %d", em_executable, shQuote(input_path),
                   shQuote(output_path), as.integer(max_em_iters))

    Log("Shell command: %s", cmd)
    exit_code <- system(cmd)

    Log("Done running shell command")
    if (exit_code != 0) {
      stop(sprintf("Command failed with code %d", exit_code))
    }

    f <- file(output_path, 'rb')
    tryCatch(.ReadResult(f, entry_size, matrix_dims), finally = close(f))
  }
}
138