#!/usr/bin/env Rscript
#
# Copyright 2015 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Provides ComputePdf and RandomPartition.  The path is resolved relative to
# the current working directory.
source('tests/gen_counts.R')

# Usage:
#
# $ ./gen_true_values.R exp 100 10000 1 64 foo.csv
#
# Inputs:
#   distribution name
#   size of the distribution's support
#   number of clients
#   reports per client
#   number of cohorts
#   name of the output file
# Output:
#   csv file (columns: client, cohort, value) with reports sampled according
#   to the specified distribution.

# Generates one row per report, with values drawn according to a distribution.
#
# Args:
#   distr: Name of the distribution, passed through to ComputePdf.
#   distr_range: Size of the distribution's support; values are "v1" through
#     "v<distr_range>".
#   num_clients: Number of clients; clients are "c1" through "c<num_clients>".
#   reports_per_client: Number of reports generated for each client.
#   num_cohorts: Number of cohorts clients are spread over.
#
# Returns:
#   A data frame with num_clients * reports_per_client rows and the columns
#   client, cohort (0-based) and value.
GenerateTrueValues <- function(distr, distr_range, num_clients,
                               reports_per_client, num_cohorts) {

  # Sums to 1.0, e.g. [0.2 0.2 0.2 0.2 0.2] for uniform distribution of 5.
  pdf <- ComputePdf(distr, distr_range)

  num_reports <- num_clients * reports_per_client

  # Computes the number of clients reporting each value, where the numbers are
  # sampled according to pdf.  (sums to num_reports)
  partition <- RandomPartition(num_reports, pdf)

  # Expand the partition.  seq_len() rather than 1:n so that a zero-sized
  # support gives an empty sequence instead of c(1, 0).
  value_ints <- rep(seq_len(distr_range), partition)

  stopifnot(length(value_ints) == num_reports)

  # Shuffle values randomly (may take a few sec for > 10^8 inputs).
  # Permute by index instead of calling sample(value_ints): for a single
  # report with value k > 1, sample(k) would return a permutation of 1:k and
  # silently produce k rows.  For longer vectors this draws exactly the same
  # permutation as sample(value_ints).
  value_ints <- value_ints[sample.int(length(value_ints))]

  # Reported values are strings, so prefix integers "v".  Even slower than
  # shuffling.
  values <- sprintf("v%d", value_ints)

  # e.g. [1 1 2 2 3 3] if num_clients is 3 and reports_per_client is 2
  client_ints <- rep(seq_len(num_clients), each = reports_per_client)

  # Cohorts are assigned to clients.  Cohorts are 0-based.
  cohorts <- client_ints %% num_cohorts  # %% is integer modulus

  clients <- sprintf("c%d", client_ints)

  data.frame(client = clients, cohort = cohorts, value = values)
}

# Command-line entry point: parses the arguments, generates the reports and
# writes them to a csv file.
#
# Args:
#   argv: Character vector (or list) of command-line arguments, in the order
#     given in the usage comment at the top of this file.  Arguments beyond
#     the sixth are ignored.
main <- function(argv) {
  if (length(argv) < 6) {
    stop("Usage: gen_true_values.R <distr> <distr_range> <num_clients> ",
         "<reports_per_client> <num_cohorts> <out_file>", call. = FALSE)
  }

  distr <- argv[[1]]
  distr_range <- as.integer(argv[[2]])
  num_clients <- as.integer(argv[[3]])
  reports_per_client <- as.integer(argv[[4]])
  num_cohorts <- as.integer(argv[[5]])
  out_file <- argv[[6]]

  # as.integer() yields NA for text that is not a number; fail here with a
  # clear message instead of somewhere deep inside the generation code.
  counts <- c(distr_range, num_clients, reports_per_client, num_cohorts)
  if (any(is.na(counts))) {
    stop("distr_range, num_clients, reports_per_client and num_cohorts ",
         "must be integers", call. = FALSE)
  }

  reports <- GenerateTrueValues(distr, distr_range, num_clients,
                                reports_per_client, num_cohorts)

  write.csv(reports, file = out_file, row.names = FALSE, quote = FALSE)
}

# Run main() only when no function call frames are active, i.e. when the file
# is evaluated at top level (as under Rscript) rather than via source().
if (length(sys.frames()) == 0) {
  main(commandArgs(TRUE))
}