xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/cpu_generator_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/Utils.h>
5 #include <ATen/CPUGeneratorImpl.h>
6 #include <ATen/core/PhiloxRNGEngine.h>
7 #include <c10/util/irange.h>
8 #include <thread>
9 #include <limits>
10 #include <random>
11 
12 using namespace at;
13 
TEST(CPUGeneratorImpl, TestGeneratorDynamicCast) {
  // Test Description:
  //   Create a CPU generator, run it through check_generator, and confirm
  //   that the type of the returned value matches CPUGeneratorImpl*.
  auto generator = at::detail::createCPUGenerator();
  auto impl = check_generator<CPUGeneratorImpl>(generator);
  const auto expected_hash = typeid(CPUGeneratorImpl*).hash_code();
  const auto actual_hash = typeid(impl).hash_code();
  ASSERT_EQ(expected_hash, actual_hash);
}
20 
TEST(CPUGeneratorImpl, TestDefaultGenerator) {
  // Test Description:
  //   The default CPU generator is expected to be created only once, so
  //   two separate requests for it must compare equal.
  auto first_handle = at::detail::getDefaultCPUGenerator();
  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
  auto second_handle = at::detail::getDefaultCPUGenerator();
  ASSERT_EQ(first_handle, second_handle);
}
30 
TEST(CPUGeneratorImpl, TestCloning) {
  // Test Description:
  //   Clone a non-default generator after advancing it and check that the
  //   clone produces the same next value as the source.
  //   (Cloning other generator states into default generators is not
  //   allowed, so only freshly created generators are used here.)
  auto source = at::detail::createCPUGenerator();
  auto source_impl = check_generator<CPUGeneratorImpl>(source);

  // Advance the source state by two draws before cloning.
  for (int draw = 0; draw < 2; ++draw) {
    // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
    source_impl->random();
  }

  auto copy = at::detail::createCPUGenerator();
  copy = source.clone();
  auto copy_impl = check_generator<CPUGeneratorImpl>(copy);

  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  ASSERT_EQ(source_impl->random(), copy_impl->random());
}
48 
thread_func_get_engine_op(CPUGeneratorImpl * generator)49 void thread_func_get_engine_op(CPUGeneratorImpl* generator) {
50   std::lock_guard<std::mutex> lock(generator->mutex_);
51   // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
52   generator->random();
53 }
54 
TEST(CPUGeneratorImpl, TestMultithreadingGetEngineOperator) {
  // Test Description:
  //   Three threads each draw one sample from the same generator (taking
  //   its lock). A clone captured beforehand is advanced three times on
  //   this thread; the next value from both must match, showing the engine
  //   state was not corrupted by the concurrent requests.
  //   See Note [Acquire lock when using random generators]
  auto shared_gen = at::detail::createCPUGenerator();
  auto shared_impl = check_generator<CPUGeneratorImpl>(shared_gen);
  auto snapshot = at::detail::createCPUGenerator();
  {
    // Capture the state of shared_gen before any thread touches it.
    std::lock_guard<std::mutex> guard(shared_gen.mutex());
    snapshot = shared_gen.clone();
  }

  std::thread workers[] = {
      std::thread{thread_func_get_engine_op, shared_impl},
      std::thread{thread_func_get_engine_op, shared_impl},
      std::thread{thread_func_get_engine_op, shared_impl}};
  for (auto& worker : workers) {
    worker.join();
  }

  std::lock_guard<std::mutex> guard(snapshot.mutex());
  auto snapshot_impl = check_generator<CPUGeneratorImpl>(snapshot);
  // Replay the three draws the worker threads performed.
  for (int draw = 0; draw < 3; ++draw) {
    // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
    snapshot_impl->random();
  }
  // NOLINTNEXTLINE(clang-analyzer-security.insecureAPI.rand)
  ASSERT_EQ(shared_impl->random(), snapshot_impl->random());
}
85 
TEST(CPUGeneratorImpl, TestGetSetCurrentSeed) {
  // Test Description:
  //   Set the seed of the default CPU generator and read it back through
  //   the getter; both happen while holding the generator's mutex.
  //   See Note [Acquire lock when using random generators]
  auto default_gen = at::detail::getDefaultCPUGenerator();
  std::lock_guard<std::mutex> guard(default_gen.mutex());
  default_gen.set_current_seed(123);
  ASSERT_EQ(default_gen.current_seed(), 123);
}
96 
thread_func_get_set_current_seed(Generator generator)97 void thread_func_get_set_current_seed(Generator generator) {
98   std::lock_guard<std::mutex> lock(generator.mutex());
99   auto current_seed = generator.current_seed();
100   current_seed++;
101   generator.set_current_seed(current_seed);
102 }
103 
TEST(CPUGeneratorImpl, TestMultithreadingGetSetCurrentSeed) {
  // Test Description:
  //   Three threads each increment the default generator's seed by one
  //   under its lock; afterwards the seed must have grown by exactly three.
  //   See Note [Acquire lock when using random generators]
  auto default_gen = at::detail::getDefaultCPUGenerator();
  const auto seed_before = default_gen.current_seed();

  std::thread workers[] = {
      std::thread{thread_func_get_set_current_seed, default_gen},
      std::thread{thread_func_get_set_current_seed, default_gen},
      std::thread{thread_func_get_set_current_seed, default_gen}};
  for (auto& worker : workers) {
    worker.join();
  }

  ASSERT_EQ(default_gen.current_seed(), seed_before + 3);
}
118 
TEST(CPUGeneratorImpl, TestRNGForking) {
  // Test Description:
  //   Freeze the default generator's state into a clone, sample from the
  //   default generator, then sample from the clone and expect the same
  //   values (compared through their sums).
  //   See Note [Acquire lock when using random generators]
  auto main_gen = at::detail::getDefaultCPUGenerator();
  auto forked_gen = at::detail::createCPUGenerator();
  {
    // Snapshot the default generator's current state under its lock.
    std::lock_guard<std::mutex> guard(main_gen.mutex());
    forked_gen = main_gen.clone();
  }

  auto expected = at::randn({1000});
  // Dramatically alter the internal state of the main generator
  auto scramble = at::randn({100000});
  auto actual = at::randn({1000}, forked_gen);

  const auto expected_sum = expected.sum().item<double>();
  const auto actual_sum = actual.sum().item<double>();
  ASSERT_EQ(expected_sum, actual_sum);
}
136 
137 /**
138  * Philox CPU Engine Tests
139  */
140 
TEST(CPUGeneratorImpl, TestPhiloxEngineReproducibility) {
  // Test Description:
  //   Two engines built from identical (seed, index, offset) arguments
  //   must be aligned: their first outputs have to be equal.
  at::Philox4_32 lhs_engine(0, 0, 4);
  at::Philox4_32 rhs_engine(0, 0, 4);
  const auto lhs_value = lhs_engine();
  const auto rhs_value = rhs_engine();
  ASSERT_EQ(lhs_value, rhs_value);
}
151 
TEST(CPUGeneratorImpl, TestPhiloxEngineOffset1) {
  // Test Description:
  //   Offsetting within a single thread index.
  //   One engine starts at offset 0 and is called 8 times; the other is
  //   constructed already skipped past the first 8 values. The 9th call
  //   of the former must equal the 1st call of the latter.
  at::Philox4_32 stepped_engine(123, 1, 0);
  // Note: the offset counts in units of 4 values, so skipping 8 values
  // means an offset of 2 (2*4=8).
  at::Philox4_32 skipped_engine(123, 1, 2);

  // Note: calling incr() twice would achieve the same as these 8 calls.
  for (int consumed = 0; consumed < 8; ++consumed) {
    stepped_engine();
  }

  ASSERT_EQ(stepped_engine(), skipped_engine());
}
172 
TEST(CPUGeneratorImpl, TestPhiloxEngineOffset2) {
  // Test Description:
  //   Edge case at the end of the 2^190th value of the generator.
  //   engine1 skips to the 2^64th 128 bit while at thread 0.
  //   engine2 skips to the 2^64th 128 bit while at the 2^64th thread.
  //   engine2 is expected to be increment_val+1 steps behind engine1, so
  //   it is advanced by that many steps before the comparison.
  unsigned long long increment_val = std::numeric_limits<uint64_t>::max();
  at::Philox4_32 engine1(123, 0, increment_val);
  at::Philox4_32 engine2(123, increment_val, increment_val);

  // Advance engine2 by increment_val steps, then by one more.
  engine2.incr_n(increment_val);
  engine2.incr();

  const auto value1 = engine1();
  const auto value2 = engine2();
  ASSERT_EQ(value1, value2);
}
188 
TEST(CPUGeneratorImpl, TestPhiloxEngineOffset3) {
  // Test Description:
  //   Edge case in between thread indices.
  //   engine1 skips to the 2^64th 128 bit while at thread 0.
  //   engine2 starts at thread 1 with offset 0.
  //   engine1 is expected to be 1 step behind engine2, so one incr() on
  //   engine1 should line the two up.
  unsigned long long increment_val = std::numeric_limits<uint64_t>::max();
  at::Philox4_32 engine1(123, 0, increment_val);
  at::Philox4_32 engine2(123, 1, 0);

  engine1.incr();

  const auto value1 = engine1();
  const auto value2 = engine2();
  ASSERT_EQ(value1, value2);
}
202 
TEST(CPUGeneratorImpl, TestPhiloxEngineIndex) {
  // Test Description:
  //   Thread indexing check: two engines share seed and offset but use
  //   different thread indices, so their first outputs must differ.
  at::Philox4_32 index0_engine(123456, 0, 4);
  at::Philox4_32 index1_engine(123456, 1, 4);
  const auto index0_value = index0_engine();
  const auto index1_value = index1_engine();
  ASSERT_NE(index0_value, index1_value);
}
212 
213 /**
214  * MT19937 CPU Engine Tests
215  */
216 
TEST(CPUGeneratorImpl, TestMT19937EngineReproducibility) {
  // Test Description:
  //   at::mt19937 must produce the same sequence as std::mt19937 when both
  //   are constructed from the same seed. Three seeds are exercised:
  //   zero, a large value (2147483647), and one drawn from
  //   std::random_device. For each seed the first 10000 outputs are
  //   compared one by one.
  std::random_device rd;
  const unsigned int seeds[] = {0u, 2147483647u, rd()};

  for (const auto seed : seeds) {
    at::mt19937 engine1(seed);
    std::mt19937 engine2(seed);
    for (const auto i : c10::irange(10000)) {
      // The last seed is random, so it is reported on failure; without it
      // a mismatch could not be reproduced.
      ASSERT_EQ(engine1(), engine2())
          << "seed = " << seed << ", draw index = " << i;
    }
  }
}
246 
TEST(CPUGeneratorImpl, TestPhiloxEngineReproducibilityRandN) {
  // Two engines constructed with identical arguments must return equal
  // results from randn(1).
  at::Philox4_32 lhs_engine(0, 0, 4);
  at::Philox4_32 rhs_engine(0, 0, 4);
  const auto lhs_sample = lhs_engine.randn(1);
  const auto rhs_sample = rhs_engine.randn(1);
  ASSERT_EQ(lhs_sample, rhs_sample);
}
252 
TEST(CPUGeneratorImpl, TestPhiloxEngineSeedRandN) {
  // Engines constructed from different seeds must return different
  // results from randn(1).
  at::Philox4_32 zero_seed_engine(0);
  at::Philox4_32 other_seed_engine(123456);
  const auto zero_seed_sample = zero_seed_engine.randn(1);
  const auto other_seed_sample = other_seed_engine.randn(1);
  ASSERT_NE(zero_seed_sample, other_seed_sample);
}
258 
TEST(CPUGeneratorImpl, TestPhiloxDeterministic) {
  // Pins the first two outputs of two fixed engine configurations to
  // hard-coded values, so any change in the produced sequence is caught.
  at::Philox4_32 engine_a(0, 0, 4);
  const auto a_first = engine_a();
  ASSERT_EQ(a_first, 4013802324);  // Determinism!
  const auto a_second = engine_a();
  ASSERT_EQ(a_second, 2979262830);  // Determinism!

  at::Philox4_32 engine_b(10, 0, 1);
  const auto b_first = engine_b();
  ASSERT_EQ(b_first, 2007330488);  // Determinism!
  const auto b_second = engine_b();
  ASSERT_EQ(b_second, 2354548925);  // Determinism!
}
268