# xref: /aosp_15_r20/external/rappor/analysis/R/fast_em.R (revision 2abb31345f6c95944768b5222a9a5ed3fc68cc00)
# fast_em.R: Wrapper around analysis/cpp/fast_em.cc.
#
# This serializes the input, shells out, and deserializes the output.
4
.Flatten <- function(list_of_matrices) {
  # Concatenates every element of a list into one flat vector.
  #
  # Args:
  #   list_of_matrices: list whose elements are matrices (or anything else
  #     as.vector() accepts).
  #
  # Returns:
  #   The elements, each passed through as.vector() (a matrix is unrolled
  #   column by column), joined in list order by unlist().  An empty list
  #   yields NULL.
  unlist(lapply(list_of_matrices, as.vector))
}
12
.WriteListOfMatrices <- function(list_of_matrices, f) {
  # Serializes a list of equally sized matrices to a binary connection.
  #
  # Layout, in write order:
  #   'ne ' tag, then the number of list entries (integer)
  #   'es ' tag, then the element count of one entry (integer)
  #   'dat' tag, then all values as doubles, one entry after another
  # Each tag is a 3-character string; writeBin() NUL-terminates it.
  #
  # Args:
  #   list_of_matrices: non-empty list of numeric matrices that all have the
  #     dimensions of the first element.
  #   f: connection opened for binary writing.  It is not closed here.
  #
  # Raises an error, before anything is written, if the list is empty or if
  # the flattened payload does not hold num_entries * entry_size values.
  num_entries <- length(list_of_matrices)
  if (num_entries == 0) {
    stop("Expected a non-empty list of matrices")
  }

  # For 2x3, this is 6
  entry_size <- as.integer(prod(dim(list_of_matrices[[1]])))

  # The header has no element-type field, so the payload is always written
  # as doubles.  Without the coercion writeBin() would emit 4-byte values
  # for integer or logical matrices.
  flattened <- as.double(.Flatten(list_of_matrices))

  # as.numeric() keeps the product from overflowing integer range.
  expected_length <- as.numeric(num_entries) * entry_size
  if (length(flattened) != expected_length) {
    stop(sprintf("Expected %.0f values (%d entries of size %d), got %.0f",
                 expected_length, num_entries, entry_size,
                 length(flattened)))
  }

  # NOTE: UpdateJointConditional does outer product of dimensions!

  # 3 letter strings are null terminated
  writeBin('ne ', con = f)
  writeBin(num_entries, con = f)

  Log('Wrote num_entries = %d', num_entries)

  writeBin('es ', con = f)
  writeBin(entry_size, con = f)

  Log('Wrote entry_size = %d', entry_size)

  # now write the data
  writeBin('dat', con = f)
  writeBin(flattened, con = f)
}
37
.ExpectTag <- function(f, tag) {
  # Consumes one NUL-terminated string from a binary connection and checks
  # that it is the expected tag.
  #
  # Args:
  #   f: connection opened for binary reading.
  #   tag: the string that must come next in the stream.
  #
  # Raises an error if no string could be read, or if the string read
  # differs from tag.

  # "char" is not one of readBin()'s named modes; readBin() then uses the
  # type of the argument itself, which is character.
  found <- readBin(con = f, what = "char", n = 1)

  if (length(found) != 1) {
    # Nothing left to read.
    stop(sprintf("Failed to read a tag '%s'", tag))
  } else if (found != tag) {
    stop(sprintf("Expected '%s', got '%s'", tag, found))
  }
}
50
.ReadResult <- function (f, entry_size, matrix_dims) {
  # Deserializes the result file produced by the external EM program.
  #
  # Expected layout: an "emi" tag followed by one integer (the iteration
  # count), then a "pij" tag followed by entry_size doubles.
  #
  # Args:
  #   f: connection opened for binary reading.
  #   entry_size: number of doubles to read after the "pij" tag.
  #   matrix_dims: dimensions assigned to the values that were read.
  #
  # Returns:
  #   list(est = <values shaped to matrix_dims>, num_em_iters = <integer>).
  #   NOTE(review): the original comment here listed "est, sd, var_cov,
  #   hist"; only est and num_em_iters are filled in by this reader.
  .ExpectTag(f, "emi")
  # NOTE: assuming R integers are 4 bytes (uint32_t)
  iterations <- readBin(con = f, what = "int", n = 1)

  .ExpectTag(f, "pij")
  estimate <- readBin(con = f, what = "double", n = entry_size)
  # Reshape the flat vector; this fails if fewer values were read than
  # matrix_dims requires.
  dim(estimate) <- matrix_dims

  Log("Number of EM iterations: %d", iterations)
  Log("PIJ read from external implementation:")
  print(estimate)

  list(est = estimate, num_em_iters = iterations)
}
69
.SanityChecks <- function(joint_conditional) {
  # Display some stats before sending it over to C++.
  #
  # Logs, summed over every matrix in the list, how many values are +Inf,
  # how many are NaN, and how many are exactly zero.
  #
  # Args:
  #   joint_conditional: list of numeric matrices.

  # Sums, over all matrices, the number of elements for which predicate is
  # TRUE.  na.rm = TRUE matters: comparing NaN/NA with == yields NA, which
  # would otherwise turn the whole total into NA as soon as one NaN exists.
  CountMatches <- function(predicate) {
    per_matrix <- vapply(joint_conditional, function(m) {
      sum(predicate(m), na.rm = TRUE)
    }, numeric(1))
    sum(per_matrix)
  }

  total_inf <- CountMatches(function(m) m == Inf)
  total_nan <- CountMatches(is.nan)
  total_zero <- CountMatches(function(m) m == 0.0)

  Log('total inf: %s', total_inf)
  Log('total nan: %s', total_nan)
  Log('total zero: %s', total_zero)
}
93
ConstructFastEM <- function(em_executable, tmp_dir) {
  # Builds an EM function that delegates the computation to an external
  # executable.
  #
  # Args:
  #   em_executable: command that launches the external EM program.  It is
  #     placed in the shell command as given (not quoted), so it must
  #     already be shell-safe.
  #   tmp_dir: directory in which the exchange files
  #     'list_of_matrices.bin' (input) and 'pij.bin' (output) are created.
  #
  # Returns:
  #   A function(joint_conditional, max_em_iters, epsilon, verbose,
  #   estimate_var).  It writes joint_conditional (a list of 2-dimensional
  #   matrices) to the input file, runs
  #     <em_executable> <input path> <output path> <max_em_iters>
  #   and returns list(est, num_em_iters) read from the output file.
  #   epsilon, verbose and estimate_var are accepted but not used here.
  #   It stops if the first matrix is not 2-dimensional or if the command
  #   exits with a non-zero code.

  function(joint_conditional, max_em_iters = 1000,
           epsilon = 10 ^ -6, verbose = FALSE,
           estimate_var = FALSE) {
    matrix_dims <- dim(joint_conditional[[1]])
    # Check that number of dimensions is 2.
    if (length(matrix_dims) != 2) {
      Log('FATAL: Expected 2 dimensions, got %d', length(matrix_dims))
      stop(sprintf('Expected 2 dimensions, got %d', length(matrix_dims)))
    }

    entry_size <- prod(matrix_dims)
    Log('entry size: %d', entry_size)

    .SanityChecks(joint_conditional)

    input_path <- file.path(tmp_dir, 'list_of_matrices.bin')
    Log("Writing flattened list of matrices to %s", input_path)
    input_con <- file(input_path, 'wb')  # binary file
    # Close even when writing fails, and always before the external
    # program runs so the data is flushed to disk.
    tryCatch(.WriteListOfMatrices(joint_conditional, input_con),
             finally = close(input_con))
    Log("Done writing %s", input_path)

    output_path <- file.path(tmp_dir, 'pij.bin')

    # sprintf("%s", 1e5) would put "1e+05" on the command line; render the
    # iteration limit in plain decimal notation instead.
    iters_arg <- format(max_em_iters, scientific = FALSE, trim = TRUE)

    # Quote the paths so a tmp_dir containing spaces or shell
    # metacharacters still reaches the program as one argument each.
    cmd <- sprintf("%s %s %s %s", em_executable, shQuote(input_path),
                   shQuote(output_path), iters_arg)

    Log("Shell command: %s", cmd)
    exit_code <- system(cmd)

    Log("Done running shell command")
    if (exit_code != 0) {
      stop(sprintf("Command failed with code %d", exit_code))
    }

    output_con <- file(output_path, 'rb')
    # Close the result file whether or not parsing succeeds.
    tryCatch(.ReadResult(output_con, entry_size, matrix_dims),
             finally = close(output_con))
  }
}
138