# fast_em.R: Wrapper around analysis/cpp/fast_em.cc.
#
# This serializes the input, shells out, and deserializes the output.
#
# NOTE(review): Log() is called throughout but is not defined in the visible
# part of this file; presumably it is provided by whatever sources this file.

# Flattens a list of matrices into a single vector.
#
# Args:
#   list_of_matrices: list of matrices (or vectors).
#
# Returns:
#   One vector holding the elements of every entry, in list order; each matrix
#   contributes its elements in as.vector() (column-major) order.
.Flatten <- function(list_of_matrices) {
  # unlist takes list to vector.
  unlist(lapply(list_of_matrices, as.vector))
}

# Serializes a list of equally sized matrices to a binary connection.
#
# Layout, in write order:
#   "ne " <num_entries>   number of matrices in the list
#   "es " <entry_size>    number of elements per matrix (for 2x3, this is 6)
#   "dat" <values>        every matrix flattened, back to back
#
# The 3 letter tag strings are NUL terminated by writeBin(). Both counts are
# written as R integers.
#
# NOTE: UpdateJointConditional does outer product of dimensions!
#
# NOTE(review): writeBin() emits the values in their storage mode, so integer
# matrices would be written as integers rather than doubles -- confirm what
# the reader on the other side expects.
#
# Args:
#   list_of_matrices: non-empty list of matrices that all have the same number
#     of elements as the first one.
#   f: connection opened for binary writing.
#
# Raises:
#   An error, before anything is written, when the list is empty or when the
#   total number of values does not equal num_entries * entry_size (the file
#   header would otherwise contradict its payload).
.WriteListOfMatrices <- function(list_of_matrices, f) {
  num_entries <- length(list_of_matrices)
  if (num_entries == 0) {
    stop("Expected a non-empty list of matrices")
  }

  entry_size <- as.integer(prod(dim(list_of_matrices[[1]])))
  flattened <- .Flatten(list_of_matrices)

  # Compare as doubles so a very large product cannot overflow to NA.
  if (length(flattened) != as.numeric(num_entries) * entry_size) {
    stop(sprintf(
      "Expected %d entries of %d values each, got %d values in total",
      num_entries, entry_size, length(flattened)))
  }

  writeBin("ne ", con = f)
  writeBin(num_entries, con = f)
  Log("Wrote num_entries = %d", num_entries)

  writeBin("es ", con = f)
  writeBin(entry_size, con = f)
  Log("Wrote entry_size = %d", entry_size)

  # now write the data
  writeBin("dat", con = f)
  writeBin(flattened, con = f)
}

# Reads one NUL-terminated string from a connection and checks that it equals
# the expected tag.
#
# Args:
#   f: connection opened for binary reading.
#   tag: the string that must come next in the stream.
#
# Raises:
#   An error when no string could be read or when it differs from tag.
.ExpectTag <- function(f, tag) {
  # Read a single NUL-terminated character string.
  actual <- readBin(con = f, what = "character", n = 1)

  # Assert that we got what was expected.
  if (length(actual) != 1) {
    stop(sprintf("Failed to read a tag '%s'", tag))
  }
  if (actual != tag) {
    stop(sprintf("Expected '%s', got '%s'", tag, actual))
  }
}

# Deserializes the output of the external EM implementation.
#
# Expected layout: tag "emi", one integer (number of EM iterations), tag
# "pij", then entry_size doubles.
#
# Args:
#   f: connection opened for binary reading.
#   entry_size: number of doubles to read after the "pij" tag.
#   matrix_dims: dimensions assigned to the values read.
#
# Returns:
#   list(est = <values shaped to matrix_dims>, num_em_iters = <integer>).
#
# Raises:
#   An error when a tag is missing or wrong, or when the stream ends before
#   the iteration count or all entry_size values have been read.
.ReadResult <- function(f, entry_size, matrix_dims) {
  .ExpectTag(f, "emi")
  # NOTE: assuming R integers are 4 bytes (uint32_t)
  num_em_iters <- readBin(con = f, what = "integer", n = 1)
  if (length(num_em_iters) != 1) {
    stop("Failed to read the number of EM iterations")
  }

  .ExpectTag(f, "pij")
  pij <- readBin(con = f, what = "double", n = entry_size)
  # readBin() returns fewer values than requested on a truncated stream; fail
  # here with the actual counts rather than inside `dim<-`.
  if (length(pij) != entry_size) {
    stop(sprintf("Expected %.0f values after 'pij', got %d",
                 entry_size, length(pij)))
  }

  # Adjust dimensions
  dim(pij) <- matrix_dims

  Log("Number of EM iterations: %d", num_em_iters)
  Log("PIJ read from external implementation:")
  print(pij)

  # est, sd, var_cov, hist
  list(est = pij, num_em_iters = num_em_iters)
}

# Logs how many +Inf, NaN and zero values the matrices contain, as a quick
# diagnostic before the data is sent over to C++.
#
# Missing values are skipped while counting +Inf and zeros; otherwise a single
# NaN would turn those two totals into NA.
#
# Args:
#   joint_conditional: list of numeric matrices.
.SanityChecks <- function(joint_conditional) {
  # Total number of elements, across all matrices, for which predicate holds.
  CountWhere <- function(predicate) {
    per_matrix <- vapply(
      joint_conditional,
      function(m) as.numeric(sum(predicate(m), na.rm = TRUE)),
      numeric(1))
    sum(per_matrix)
  }

  total_inf <- CountWhere(function(m) m == Inf)
  total_nan <- CountWhere(is.nan)
  total_zero <- CountWhere(function(m) m == 0.0)

  Log("total inf: %s", total_inf)
  Log("total nan: %s", total_nan)
  Log("total zero: %s", total_zero)
}

# Builds a function that runs EM through the external executable.
#
# Args:
#   em_executable: command used to start the EM binary. It is inserted into
#     the shell command verbatim (not quoted), so it may carry leading
#     arguments; a path containing spaces must be quoted by the caller.
#   tmp_dir: directory in which the exchange files 'list_of_matrices.bin' and
#     'pij.bin' are created.
#
# Returns:
#   function(joint_conditional, max_em_iters, epsilon, verbose, estimate_var)
#   which writes joint_conditional (a non-empty list of 2-dimensional
#   matrices) to tmp_dir, runs
#     <em_executable> <input path> <output path> <max_em_iters>
#   and returns list(est = <matrix>, num_em_iters = <integer>) read back from
#   the output file. epsilon, verbose and estimate_var are accepted but not
#   used by this wrapper.
#
#   The returned function raises an error when the input is empty or not
#   2-dimensional, when the command exits with a non-zero code, or when the
#   output file cannot be parsed.
ConstructFastEM <- function(em_executable, tmp_dir) {
  function(joint_conditional, max_em_iters = 1000,
           epsilon = 10 ^ -6, verbose = FALSE,
           estimate_var = FALSE) {
    if (length(joint_conditional) == 0) {
      stop("joint_conditional must be a non-empty list of matrices")
    }

    matrix_dims <- dim(joint_conditional[[1]])
    # Check that number of dimensions is 2.
    if (length(matrix_dims) != 2) {
      Log("FATAL: Expected 2 dimensions, got %d", length(matrix_dims))
      stop(sprintf("Expected 2 dimensions, got %d", length(matrix_dims)))
    }

    entry_size <- prod(matrix_dims)
    Log("entry size: %d", entry_size)

    .SanityChecks(joint_conditional)

    input_path <- file.path(tmp_dir, "list_of_matrices.bin")
    Log("Writing flattened list of matrices to %s", input_path)
    f <- file(input_path, "wb")  # binary file
    # Close the connection even when serialization fails.
    tryCatch(
      .WriteListOfMatrices(joint_conditional, f),
      finally = close(f)
    )
    Log("Done writing %s", input_path)

    output_path <- file.path(tmp_dir, "pij.bin")

    # Fixed notation: "%s" would render e.g. 100000 as "1e+05".
    iters_arg <- format(max_em_iters, scientific = FALSE, trim = TRUE)

    # Quote the paths so a tmp_dir containing spaces or shell metacharacters
    # still reaches the executable as single arguments.
    cmd <- sprintf("%s %s %s %s", em_executable, shQuote(input_path),
                   shQuote(output_path), iters_arg)

    Log("Shell command: %s", cmd)
    exit_code <- system(cmd)

    Log("Done running shell command")
    if (exit_code != 0) {
      stop(sprintf("Command failed with code %d", exit_code))
    }

    f <- file(output_path, "rb")
    # Close the connection even when the output cannot be parsed.
    tryCatch(
      .ReadResult(f, entry_size, matrix_dims),
      finally = close(f)
    )
  }
}