xref: /aosp_15_r20/external/pytorch/c10/test/util/ThreadLocal_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/ThreadLocal.h>
2 #include <gtest/gtest.h>
3 
4 #include <atomic>
5 #include <thread>
6 
7 namespace {
8 
// Defining a single function-scope TLS variable that is never touched must
// compile and must not misbehave when the test body exits.
// The suite name is `ThreadLocalTest`, matching every other test in this
// file, so that `--gtest_filter=ThreadLocalTest.*` selects this case too.
TEST(ThreadLocalTest, TestNoOpScopeWithOneVar) {
  C10_DEFINE_TLS_static(std::string, str);
}
12 
// Two function-scope TLS variables defined side by side in the same scope,
// neither of which is ever accessed. The test has no assertions: it passes
// as long as both definitions compile and the body runs to completion.
TEST(ThreadLocalTest, TestNoOpScopeWithTwoVars) {
  C10_DEFINE_TLS_static(std::string, str);
  C10_DEFINE_TLS_static(std::string, str2);
}
17 
// Exercises every accessor on a single function-scope TLS string:
// operator*, operator-> and get().
TEST(ThreadLocalTest, TestScopeWithOneVar) {
  C10_DEFINE_TLS_static(std::string, str);

  // Before the first write the test expects an empty string.
  EXPECT_EQ(*str, std::string());
  EXPECT_EQ(*str, "");

  // A write through operator* must be visible through all three accessors.
  *str = "abc";
  EXPECT_EQ(*str, "abc");
  // Unsigned literal: length() returns an unsigned type, and comparing it to
  // a signed int inside EXPECT_EQ triggers sign-compare warnings.
  EXPECT_EQ(str->length(), 3u);
  EXPECT_EQ(str.get(), "abc");
}
28 
// Two TLS strings defined in the same function scope must hold independent
// values: writing or clearing one must never be observable through the other.
TEST(ThreadLocalTest, TestScopeWithTwoVars) {
  C10_DEFINE_TLS_static(std::string, str);
  EXPECT_EQ(*str, "");

  // The second variable is defined only after the first has been read.
  C10_DEFINE_TLS_static(std::string, str2);

  // Writing `str` leaves `str2` untouched.
  *str = "abc";
  EXPECT_EQ(*str, "abc");
  EXPECT_EQ(*str2, "");

  // Assigning from `str` copies the value into `str2`.
  *str2 = *str;
  EXPECT_EQ(*str, "abc");
  EXPECT_EQ(*str2, "abc");

  // Clearing the source must not affect the copy.
  str->clear();
  EXPECT_EQ(*str, "");
  EXPECT_EQ(*str2, "abc");
}
47 
// Same independence check as the two-variable test, but the second TLS
// variable is defined inside a nested block while the first one lives in the
// enclosing function scope.
TEST(ThreadLocalTest, TestInnerScopeWithTwoVars) {
  C10_DEFINE_TLS_static(std::string, str);
  *str = "abc";

  {
    C10_DEFINE_TLS_static(std::string, str2);
    EXPECT_EQ(*str2, "");

    // Copy outer -> inner; both now read "abc".
    *str2 = *str;
    EXPECT_EQ(*str, "abc");
    EXPECT_EQ(*str2, "abc");

    // Clearing the outer variable must leave the inner copy intact.
    str->clear();
    EXPECT_EQ(*str2, "abc");
  }

  // The clear performed inside the nested block is still visible out here.
  EXPECT_EQ(*str, "");
}
66 
67 struct Foo {
68   C10_DECLARE_TLS_class_static(Foo, std::string, str_);
69 };
70 
71 C10_DEFINE_TLS_class_static(Foo, std::string, str_);
72 
// A class-static TLS member must behave like a function-scope one: the test
// expects it to start empty and to reflect a write through operator*,
// operator-> and get().
TEST(ThreadLocalTest, TestClassScope) {
  EXPECT_EQ(*Foo::str_, "");

  *Foo::str_ = "abc";
  EXPECT_EQ(*Foo::str_, "abc");
  // Unsigned literal: length() returns an unsigned type, and comparing it to
  // a signed int inside EXPECT_EQ triggers sign-compare warnings.
  EXPECT_EQ(Foo::str_->length(), 3u);
  EXPECT_EQ(Foo::str_.get(), "abc");
}
81 
82 C10_DEFINE_TLS_static(std::string, global_);
83 C10_DEFINE_TLS_static(std::string, global2_);
TEST(ThreadLocalTest,TestTwoGlobalScopeVars)84 TEST(ThreadLocalTest, TestTwoGlobalScopeVars) {
85   EXPECT_EQ(*global_, "");
86   EXPECT_EQ(*global2_, "");
87 
88   *global_ = "abc";
89   EXPECT_EQ(global_->length(), 3);
90   EXPECT_EQ(*global_, "abc");
91   EXPECT_EQ(*global2_, "");
92 
93   *global2_ = *global_;
94   EXPECT_EQ(*global_, "abc");
95   EXPECT_EQ(*global2_, "abc");
96 
97   global_->clear();
98   EXPECT_EQ(*global_, "");
99   EXPECT_EQ(*global2_, "abc");
100   EXPECT_EQ(global2_.get(), "abc");
101 }
102 
103 C10_DEFINE_TLS_static(std::string, global3_);
TEST(ThreadLocalTest,TestGlobalWithLocalScopeVars)104 TEST(ThreadLocalTest, TestGlobalWithLocalScopeVars) {
105   *global3_ = "abc";
106 
107   C10_DEFINE_TLS_static(std::string, str);
108 
109   std::swap(*global3_, *str);
110   EXPECT_EQ(*str, "abc");
111   EXPECT_EQ(*global3_, "");
112 }
113 
// A function-scope TLS variable must have a separate value per thread: a
// worker thread is expected to see an empty string even though the main
// thread already wrote "abc", and the worker's write must not leak back.
TEST(ThreadLocalTest, TestThreadWithLocalScopeVar) {
  C10_DEFINE_TLS_static(std::string, str);
  *str = "abc";

  // Set by the worker so the test can prove the lambda body actually ran.
  std::atomic<bool> worker_ran{false};
  std::thread worker([&worker_ran]() {
    EXPECT_EQ(*str, "");
    *str = "def";
    worker_ran = true;
    EXPECT_EQ(*str, "def");
  });
  worker.join();

  EXPECT_TRUE(worker_ran);
  // The main thread's value is unaffected by the worker's assignment.
  EXPECT_EQ(*str, "abc");
}
130 
131 C10_DEFINE_TLS_static(std::string, global4_);
TEST(ThreadLocalTest,TestThreadWithGlobalScopeVar)132 TEST(ThreadLocalTest, TestThreadWithGlobalScopeVar) {
133   *global4_ = "abc";
134 
135   std::atomic_bool b(false);
136   std::thread t([&b]() {
137     EXPECT_EQ(*global4_, "");
138     *global4_ = "def";
139     b = true;
140     EXPECT_EQ(*global4_, "def");
141   });
142   t.join();
143 
144   EXPECT_TRUE(b);
145   EXPECT_EQ(*global4_, "abc");
146 }
147 
// Checks that an object held in a function-scope TLS variable is destroyed
// when the thread that used it exits: after the worker thread has been
// joined, exactly one construction and one destruction must have been
// counted.
TEST(ThreadLocalTest, TestObjectsAreReleased) {
  // The counters are function-local statics so that the local struct below
  // can refer to them; they therefore survive across invocations of this
  // test body. Reset them on entry so the exact-count expectations at the
  // end also hold when the test is run more than once in the same process
  // (for example with --gtest_repeat).
  static std::atomic<int> ctors{0};
  static std::atomic<int> dtors{0};
  ctors = 0;
  dtors = 0;

  // Non-copyable probe type that counts its constructions and destructions.
  struct A {
    A() : i() {
      ++ctors;
    }

    ~A() {
      ++dtors;
    }

    A(const A&) = delete;
    A& operator=(const A&) = delete;

    int i;
  };

  C10_DEFINE_TLS_static(A, a);

  // Only the worker thread dereferences `a`; this thread never does.
  std::atomic_bool b(false);
  std::thread t([&b]() {
    EXPECT_EQ(a->i, 0);
    a->i = 1;
    EXPECT_EQ(a->i, 1);
    b = true;
  });
  t.join();

  // The lambda body must actually have run.
  EXPECT_TRUE(b.load());

  // Compare plain ints so a failure message prints the counter values.
  EXPECT_EQ(ctors.load(), 1);
  EXPECT_EQ(dtors.load(), 1);
}
182 
// Same release check as the static-TLS variant, but the ::c10::ThreadLocal
// wrapper itself is a plain (non-static) local inside the worker thread.
// After the worker has been joined, exactly one construction and one
// destruction must have been counted.
TEST(ThreadLocalTest, TestObjectsAreReleasedByNonstaticThreadLocal) {
  // The counters are function-local statics so that the local struct below
  // can refer to them; they therefore survive across invocations of this
  // test body. Reset them on entry so the exact-count expectations at the
  // end also hold when the test is run more than once in the same process
  // (for example with --gtest_repeat).
  static std::atomic<int> ctors(0);
  static std::atomic<int> dtors(0);
  ctors = 0;
  dtors = 0;

  // Non-copyable probe type that counts its constructions and destructions.
  struct A {
    A() : i() {
      ++ctors;
    }

    ~A() {
      ++dtors;
    }

    A(const A&) = delete;
    A& operator=(const A&) = delete;

    int i;
  };

  std::atomic_bool b(false);
  std::thread t([&b]() {
    // The wrapper is constructed differently depending on the build
    // configuration: default-constructed when the custom TLS implementation
    // is preferred, otherwise built from an accessor that returns the
    // address of a `static thread_local` instance.
#if defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
    ::c10::ThreadLocal<A> a;
#else // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)
    ::c10::ThreadLocal<A> a([]() {
      static thread_local A var;
      return &var;
    });
#endif // defined(C10_PREFER_CUSTOM_THREAD_LOCAL_STORAGE)

    EXPECT_EQ(a->i, 0);
    a->i = 1;
    EXPECT_EQ(a->i, 1);
    b = true;
  });
  t.join();

  // The lambda body must actually have run.
  EXPECT_TRUE(b.load());

  // Compare plain ints so a failure message prints the counter values.
  EXPECT_EQ(ctors.load(), 1);
  EXPECT_EQ(dtors.load(), 1);
}
224 
225 } // namespace
226