// Copyright 2022 Google LLC // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. #include #include #include #include #include #include #include #include #include #include #include #include #include template class UnaryTest : public ::testing::Test { protected: UnaryTest() { random_device = std::unique_ptr(new std::random_device()); rng = std::mt19937((*random_device)()); shape_dist = std::uniform_int_distribution(min_dim, XNN_MAX_TENSOR_DIMS); dim_dist = std::uniform_int_distribution(1, 9); i8dist = std::uniform_int_distribution(std::numeric_limits::min(), std::numeric_limits::max()); u8dist = std::uniform_int_distribution(std::numeric_limits::min(), std::numeric_limits::max()); u32dist = std::uniform_int_distribution(); scale_dist = std::uniform_real_distribution(0.1f, 10.0f); f32dist = std::uniform_real_distribution(0.01f, 1.0f); dims = RandomShape(); channels = dims.empty() ? 1 : dims.back(); xnn_shape shape = { .num_dims = dims.size(), }; memcpy(shape.dim, dims.data(), dims.size() * sizeof(size_t)); batch_size = xnn_shape_multiply_non_channel_dims(&shape); num_output_elements = batch_size * channels; scale = scale_dist(rng); signed_zero_point = i8dist(rng); unsigned_zero_point = u8dist(rng); input = std::vector(num_output_elements + XNN_EXTRA_BYTES / sizeof(InputType)); operator_output = std::vector(num_output_elements); subgraph_output = std::vector(num_output_elements); } std::vector RandomShape() { std::vector dims(shape_dist(rng)); std::generate(dims.begin(), dims.end(), [&] { return dim_dist(rng); }); return dims; } size_t NumElements(std::vector& dims) { return std::accumulate(dims.begin(), dims.end(), size_t(1), std::multiplies()); } std::unique_ptr random_device; std::mt19937 rng; std::uniform_int_distribution shape_dist; std::uniform_int_distribution dim_dist; std::uniform_real_distribution scale_dist; std::uniform_int_distribution i8dist; std::uniform_int_distribution u8dist; std::uniform_int_distribution u32dist; std::uniform_real_distribution f32dist; std::vector dims; uint32_t input_id; uint32_t output_id; size_t channels; size_t batch_size; size_t num_output_elements; float scale; int32_t signed_zero_point; int32_t unsigned_zero_point; std::vector input; std::vector operator_output; std::vector subgraph_output; };