#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/Utils.h>
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/core/PhiloxRNGEngine.h>
#include <c10/util/irange.h>

#include <limits>
#include <mutex>
#include <random>
#include <thread>
#include <vector>
11
12 using namespace at;
13
TEST(CPUGeneratorImpl, TestGeneratorDynamicCast) {
  // Test Description: Check dynamic cast for CPU.
  // check_generator<CPUGeneratorImpl> must accept a freshly created CPU
  // generator and hand back a usable CPUGeneratorImpl pointer.
  auto foo = at::detail::createCPUGenerator();
  auto result = check_generator<CPUGeneratorImpl>(foo);
  ASSERT_TRUE(result != nullptr);
  // Compare the std::type_info objects themselves: equal hash_code() values
  // are not guaranteed to imply equal types.
  ASSERT_TRUE(typeid(CPUGeneratorImpl*) == typeid(result));
}
20
TEST(CPUGeneratorImpl, TestDefaultGenerator) {
  // Test Description:
  // The default CPU generator must be created only once: two separate
  // calls to getDefaultCPUGenerator() have to yield handles that compare
  // equal (i.e. refer to the same generator).
  auto first_handle = at::detail::getDefaultCPUGenerator();
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto second_handle = at::detail::getDefaultCPUGenerator();
  ASSERT_EQ(first_handle, second_handle);
}
30
TEST(CPUGeneratorImpl, TestCloning) {
  // Test Description:
  // Check cloning of new generators.
  // Note that we don't allow cloning of other
  // generator states into default generators.
  auto gen1 = at::detail::createCPUGenerator();
  auto cpu_gen1 = check_generator<CPUGeneratorImpl>(gen1);
  // Advance gen1 so the clone captures a non-initial state.
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  cpu_gen1->random();
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  cpu_gen1->random();
  // Initialize directly from the clone; creating a fresh generator first
  // would only be discarded by the subsequent assignment.
  auto gen2 = gen1.clone();
  auto cpu_gen2 = check_generator<CPUGeneratorImpl>(gen2);
  // The clone must continue from exactly the same state as gen1.
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  ASSERT_EQ(cpu_gen1->random(), cpu_gen2->random());
}
48
thread_func_get_engine_op(CPUGeneratorImpl * generator)49 void thread_func_get_engine_op(CPUGeneratorImpl* generator) {
50 std::lock_guard<std::mutex> lock(generator->mutex_);
51 // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
52 generator->random();
53 }
54
TEST(CPUGeneratorImpl, TestMultithreadingGetEngineOperator) {
  // Test Description:
  // Check CPUGeneratorImpl is reentrant and the engine state
  // is not corrupted when multiple threads request for
  // random samples.
  // See Note [Acquire lock when using random generators]
  constexpr int kNumThreads = 3;
  auto gen1 = at::detail::createCPUGenerator();
  auto cpu_gen1 = check_generator<CPUGeneratorImpl>(gen1);
  auto gen2 = at::detail::createCPUGenerator();
  {
    std::lock_guard<std::mutex> lock(gen1.mutex());
    gen2 = gen1.clone(); // snapshot gen1's state before the workers advance it
  }
  // Each worker draws exactly one value from gen1 while holding its mutex.
  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);
  for (C10_UNUSED const auto i : c10::irange(kNumThreads)) {
    threads.emplace_back(thread_func_get_engine_op, cpu_gen1);
  }
  for (auto& worker : threads) {
    worker.join();
  }
  // Every worker has been joined, so taking both locks here cannot deadlock;
  // holding gen1's lock keeps the final gen1 draw compliant with the Note.
  std::lock_guard<std::mutex> lock1(gen1.mutex());
  std::lock_guard<std::mutex> lock2(gen2.mutex());
  auto cpu_gen2 = check_generator<CPUGeneratorImpl>(gen2);
  // Replay the workers' draws on the snapshot so it catches up with gen1.
  for (C10_UNUSED const auto i : c10::irange(kNumThreads)) {
    // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
    cpu_gen2->random();
  }
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  ASSERT_EQ(cpu_gen1->random(), cpu_gen2->random());
}
85
TEST(CPUGeneratorImpl, TestGetSetCurrentSeed) {
  // Test Description:
  // Test current seed getter and setter
  // See Note [Acquire lock when using random generators]
  // Use an unsigned 64-bit expected value so ASSERT_EQ does not perform a
  // signed/unsigned comparison against the generator's seed.
  constexpr uint64_t kSeed = 123;
  auto foo = at::detail::getDefaultCPUGenerator();
  std::lock_guard<std::mutex> lock(foo.mutex());
  foo.set_current_seed(kSeed);
  auto current_seed = foo.current_seed();
  ASSERT_EQ(current_seed, kSeed);
}
96
thread_func_get_set_current_seed(Generator generator)97 void thread_func_get_set_current_seed(Generator generator) {
98 std::lock_guard<std::mutex> lock(generator.mutex());
99 auto current_seed = generator.current_seed();
100 current_seed++;
101 generator.set_current_seed(current_seed);
102 }
103
TEST(CPUGeneratorImpl, TestMultithreadingGetSetCurrentSeed) {
  // Test Description:
  // Test current seed getter and setter are thread safe
  // See Note [Acquire lock when using random generators]
  constexpr int kNumThreads = 3;
  auto gen1 = at::detail::getDefaultCPUGenerator();
  // Read the starting seed under the generator's mutex, as the Note requires.
  const auto initial_seed = [&] {
    std::lock_guard<std::mutex> lock(gen1.mutex());
    return gen1.current_seed();
  }();
  // Each worker increments the seed by exactly one under the mutex.
  std::vector<std::thread> threads;
  threads.reserve(kNumThreads);
  for (C10_UNUSED const auto i : c10::irange(kNumThreads)) {
    threads.emplace_back(thread_func_get_set_current_seed, gen1);
  }
  for (auto& worker : threads) {
    worker.join();
  }
  std::lock_guard<std::mutex> lock(gen1.mutex());
  ASSERT_EQ(gen1.current_seed(), initial_seed + kNumThreads);
}
118
TEST(CPUGeneratorImpl, TestRNGForking) {
  // Test Description:
  // Test that state of a generator can be frozen and
  // restored
  // See Note [Acquire lock when using random generators]
  auto default_gen = at::detail::getDefaultCPUGenerator();
  auto current_gen = at::detail::createCPUGenerator();
  {
    std::lock_guard<std::mutex> lock(default_gen.mutex());
    current_gen = default_gen.clone(); // capture the current state of default generator
  }
  auto target_value = at::randn({1000});
  // Dramatically alter the internal state of the main generator
  auto x = at::randn({100000});
  auto forked_value = at::randn({1000}, current_gen);
  // Compare element-wise: matching sums would not imply matching tensors.
  ASSERT_TRUE(at::equal(target_value, forked_value));
}
136
137 /**
138 * Philox CPU Engine Tests
139 */
140
TEST(CPUGeneratorImpl, TestPhiloxEngineReproducibility) {
  // Test Description:
  // Identical constructor inputs must give identical results. Two engines
  // built on the same thread index with the same seed and offset have to be
  // aligned, so their first outputs must match.
  constexpr uint64_t kSeed = 0;
  constexpr uint64_t kThreadIdx = 0;
  constexpr uint64_t kOffset = 4;
  at::Philox4_32 lhs_engine(kSeed, kThreadIdx, kOffset);
  at::Philox4_32 rhs_engine(kSeed, kThreadIdx, kOffset);
  ASSERT_EQ(lhs_engine(), rhs_engine());
}
151
TEST(CPUGeneratorImpl, TestPhiloxEngineOffset1) {
  // Test Description:
  // Offsetting within a single thread index.
  // engine1 starts at offset 0 and is stepped through its first 8 outputs
  // by hand; engine2 is constructed with those 8 outputs already skipped.
  // engine2's first output must therefore equal engine1's 9th output.
  constexpr int kValuesToSkip = 8;
  at::Philox4_32 engine1(123, 1, 0);
  // The offset argument counts groups of 4 values, so skipping 8 values
  // means an offset of 8 / 4 = 2.
  at::Philox4_32 engine2(123, 1, kValuesToSkip / 4);
  for (C10_UNUSED const auto step : c10::irange(kValuesToSkip)) {
    // Drawing the values one by one; calling incr() twice would
    // achieve the same skip.
    engine1();
  }
  ASSERT_EQ(engine1(), engine2());
}
172
TEST(CPUGeneratorImpl, TestPhiloxEngineOffset2) {
  // Test Description:
  // Edge case at the very end of the generator's counter range.
  // engine1: thread index 0, offset set to the largest uint64_t value.
  // engine2: thread index and offset both set to the largest uint64_t value.
  // After advancing engine2 by max + 1 further steps, both engines must
  // produce the same next value (engine2 is max + 1 steps behind engine1).
  const uint64_t max_step = std::numeric_limits<uint64_t>::max();
  at::Philox4_32 engine1(123, 0, max_step);
  at::Philox4_32 engine2(123, max_step, max_step);

  engine2.incr_n(max_step);
  engine2.incr();
  ASSERT_EQ(engine1(), engine2());
}
188
TEST(CPUGeneratorImpl, TestPhiloxEngineOffset3) {
  // Test Description:
  // Edge case at the boundary between two thread indices.
  // engine1 sits on thread index 0 with the largest possible offset;
  // engine2 starts fresh on thread index 1 at offset 0.
  // A single incr() on engine1 must bring it level with engine2.
  const uint64_t max_step = std::numeric_limits<uint64_t>::max();
  at::Philox4_32 engine1(123, 0, max_step);
  at::Philox4_32 engine2(123, 1, 0);
  engine1.incr();
  ASSERT_EQ(engine1(), engine2());
}
202
TEST(CPUGeneratorImpl, TestPhiloxEngineIndex) {
  // Test Description:
  // Thread indexing: two engines that share seed and offset but sit on
  // different thread indices must produce different sequences (checked on
  // their first outputs).
  constexpr uint64_t kSeed = 123456;
  constexpr uint64_t kOffset = 4;
  at::Philox4_32 thread0_engine(kSeed, 0, kOffset);
  at::Philox4_32 thread1_engine(kSeed, 1, kOffset);
  ASSERT_NE(thread0_engine(), thread1_engine());
}
212
213 /**
214 * MT19937 CPU Engine Tests
215 */
216
TEST(CPUGeneratorImpl, TestMT19937EngineReproducibility) {
  // Test Description:
  // Tests if same inputs give same results when compared
  // to std::mt19937: for a zero seed, a large fixed seed and a
  // non-deterministic seed, both engines must emit identical streams.
  constexpr int kNumDraws = 10000;

  // test with zero seed
  at::mt19937 engine1(0);
  std::mt19937 engine2(0);
  for (C10_UNUSED const auto i : c10::irange(kNumDraws)) {
    ASSERT_EQ(engine1(), engine2());
  }

  // test with large seed
  engine1 = at::mt19937(2147483647);
  engine2 = std::mt19937(2147483647);
  for (C10_UNUSED const auto i : c10::irange(kNumDraws)) {
    ASSERT_EQ(engine1(), engine2());
  }

  // test with random seed
  std::random_device rd;
  auto seed = rd();
  // The seed differs on every run; attach it to any failure message so the
  // failing case can be reproduced.
  SCOPED_TRACE(::testing::Message() << "random seed: " << seed);
  engine1 = at::mt19937(seed);
  engine2 = std::mt19937(seed);
  for (C10_UNUSED const auto i : c10::irange(kNumDraws)) {
    ASSERT_EQ(engine1(), engine2());
  }
}
246
TEST(CPUGeneratorImpl, TestPhiloxEngineReproducibilityRandN) {
  // Identically constructed engines must also agree on randn() output.
  at::Philox4_32 first_engine(0, 0, 4);
  at::Philox4_32 second_engine(0, 0, 4);
  const auto first_sample = first_engine.randn(1);
  const auto second_sample = second_engine.randn(1);
  ASSERT_EQ(first_sample, second_sample);
}
252
TEST(CPUGeneratorImpl, TestPhiloxEngineSeedRandN) {
  // Engines that differ only in their seed (other constructor arguments
  // left at their defaults) must disagree on randn() output.
  at::Philox4_32 seed_zero_engine(0);
  at::Philox4_32 seed_nonzero_engine(123456);
  const auto zero_seed_sample = seed_zero_engine.randn(1);
  const auto nonzero_seed_sample = seed_nonzero_engine.randn(1);
  ASSERT_NE(zero_seed_sample, nonzero_seed_sample);
}
258
TEST(CPUGeneratorImpl, TestPhiloxDeterministic) {
  // Pin the exact leading outputs for two fixed (seed, thread index, offset)
  // configurations so the engine's output sequence stays deterministic.
  at::Philox4_32 seed0_engine(0, 0, 4);
  ASSERT_EQ(seed0_engine(), 4013802324u);
  ASSERT_EQ(seed0_engine(), 2979262830u);

  at::Philox4_32 seed10_engine(10, 0, 1);
  ASSERT_EQ(seed10_engine(), 2007330488u);
  ASSERT_EQ(seed10_engine(), 2354548925u);
}
268