xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/scalar_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <iostream>
4 #include <random>
5 #include <c10/core/SymInt.h>
6 // define constants like M_PI and C keywords for MSVC
7 #ifdef _MSC_VER
8 #ifndef _USE_MATH_DEFINES
9 #define _USE_MATH_DEFINES
10 #endif
11 #include <math.h>
12 #endif
13 #include <ATen/ATen.h>
14 #include <ATen/Dispatch.h>
15 
16 // We intentionally test self assignment/move in this file, suppress warnings
17 // on them
18 #ifndef _MSC_VER
19 #pragma GCC diagnostic ignored "-Wpragmas"
20 #pragma GCC diagnostic ignored "-Wunknown-warning-option"
21 #pragma GCC diagnostic ignored "-Wself-move"
22 #endif
23 
24 #ifdef __clang__
25 #pragma clang diagnostic ignored "-Wself-assign-overloaded"
26 #endif
27 
28 using std::cout;
29 using namespace at;
30 
31 template<typename scalar_type>
32 struct Foo {
applyFoo33   static void apply(Tensor a, Tensor b) {
34     scalar_type s = 1;
35     std::stringstream ss;
36     ss << "hello, dispatch: " << a.toString() << s << "\n";
37     auto data = (scalar_type*)a.data_ptr();
38     (void)data;
39   }
40 };
41 template<>
42 struct Foo<Half> {
applyFoo43   static void apply(Tensor a, Tensor b) {}
44 };
45 
test_overflow()46 void test_overflow() {
47   auto s1 = Scalar(M_PI);
48   ASSERT_EQ(s1.toFloat(), static_cast<float>(M_PI));
49   s1.toHalf();
50 
51   s1 = Scalar(100000);
52   ASSERT_EQ(s1.toFloat(), 100000.0);
53   ASSERT_EQ(s1.toInt(), 100000);
54 
55   // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
56   ASSERT_THROW(s1.toHalf(), std::runtime_error);
57 
58   s1 = Scalar(NAN);
59   ASSERT_TRUE(std::isnan(s1.toFloat()));
60   // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
61   ASSERT_THROW(s1.toInt(), std::runtime_error);
62 
63   s1 = Scalar(INFINITY);
64   ASSERT_TRUE(std::isinf(s1.toFloat()));
65   // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
66   ASSERT_THROW(s1.toInt(), std::runtime_error);
67 }
68 
TEST(TestScalar,TestScalar)69 TEST(TestScalar, TestScalar) {
70   manual_seed(123);
71 
72   Scalar what = 257;
73   Scalar bar = 3.0;
74   Half h = bar.toHalf();
75   Scalar h2 = h;
76   cout << "H2: " << h2.toDouble() << " " << what.toFloat() << " "
77        << bar.toDouble() << " " << what.isIntegral(false) << "\n";
78   auto gen = at::detail::getDefaultCPUGenerator();
79   {
80     // See Note [Acquire lock when using random generators]
81     std::lock_guard<std::mutex> lock(gen.mutex());
82     // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
83     ASSERT_NO_THROW(gen.set_current_seed(std::random_device()()));
84   }
85   if (at::hasCUDA()) {
86     auto t2 = zeros({4, 4}, at::kCUDA);
87     cout << &t2 << "\n";
88   }
89   auto t = ones({4, 4});
90 
91   auto wha2 = zeros({4, 4}).add(t).sum();
92   ASSERT_EQ(wha2.item<double>(), 16.0);
93 
94   ASSERT_EQ(t.sizes()[0], 4);
95   ASSERT_EQ(t.sizes()[1], 4);
96   ASSERT_EQ(t.strides()[0], 4);
97   ASSERT_EQ(t.strides()[1], 1);
98 
99   TensorOptions options = dtype(kFloat);
100   Tensor x = randn({1, 10}, options);
101   Tensor prev_h = randn({1, 20}, options);
102   Tensor W_h = randn({20, 20}, options);
103   Tensor W_x = randn({20, 10}, options);
104   Tensor i2h = at::mm(W_x, x.t());
105   Tensor h2h = at::mm(W_h, prev_h.t());
106   Tensor next_h = i2h.add(h2h);
107   next_h = next_h.tanh();
108 
109   // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
110   ASSERT_ANY_THROW(Tensor{}.item());
111 
112   test_overflow();
113 
114   if (at::hasCUDA()) {
115     auto r = next_h.to(at::Device(kCUDA), kFloat, /*non_blocking=*/ false, /*copy=*/ true);
116     ASSERT_TRUE(r.to(at::Device(kCPU), kFloat, /*non_blocking=*/ false, /*copy=*/ true).equal(next_h));
117   }
118   // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
119   ASSERT_NO_THROW(randn({10, 10, 2}, options));
120 
121   // check Scalar.toTensor on Scalars backed by different data types
122   ASSERT_EQ(scalar_to_tensor(bar).scalar_type(), kDouble);
123   ASSERT_EQ(scalar_to_tensor(what).scalar_type(), kLong);
124   ASSERT_EQ(scalar_to_tensor(ones({}).item()).scalar_type(), kDouble);
125 
126   if (x.scalar_type() != ScalarType::Half) {
127     AT_DISPATCH_ALL_TYPES(x.scalar_type(), "foo", [&] {
128       scalar_t s = 1;
129       std::stringstream ss;
130       // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
131       ASSERT_NO_THROW(
132           ss << "hello, dispatch" << x.toString() << s << "\n");
133       auto data = (scalar_t*)x.data_ptr();
134       (void)data;
135     });
136   }
137 
138   // test direct C-scalar type conversions
139   {
140     auto x = ones({1, 2}, options);
141     // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
142     ASSERT_ANY_THROW(x.item<float>());
143   }
144   auto float_one = ones({}, options);
145   ASSERT_EQ(float_one.item<float>(), 1);
146   ASSERT_EQ(float_one.item<int32_t>(), 1);
147   ASSERT_EQ(float_one.item<at::Half>(), 1);
148 }
149 
// conj() must be the identity on real-valued Scalars (int and double). It must
// negate the imaginary part of a complex Scalar.
TEST(TestScalar, TestConj) {
  const c10::complex<double> z(2.3, 3.5);
  const Scalar as_int = 257;
  const Scalar as_double = 3.0;
  const Scalar as_complex = z;

  const Scalar int_conj = as_int.conj();
  ASSERT_EQ(int_conj.toInt(), 257);

  const Scalar double_conj = as_double.conj();
  ASSERT_EQ(double_conj.toDouble(), 3.0);

  const Scalar complex_conj = as_complex.conj();
  ASSERT_EQ(complex_conj.toComplexDouble(), c10::complex<double>(2.3, -3.5));
}
159 
// Pins down Scalar::equal across value kinds:
//   - bool and double Scalars never compare equal to each other;
//   - ints, doubles and complex numbers with a zero imaginary part compare
//     by numeric value.
TEST(TestScalar, TestEqual) {
  using Complex = c10::complex<double>;

  // bool vs. double, in both directions.
  const Scalar one_point_zero(1.0);
  ASSERT_FALSE(one_point_zero.equal(false));
  ASSERT_FALSE(one_point_zero.equal(true));

  const Scalar truth(true);
  ASSERT_FALSE(truth.equal(1.0));
  ASSERT_TRUE(truth.equal(true));

  // Complex receivers.
  const Scalar two_five(Complex{2.0, 5.0});
  ASSERT_TRUE(two_five.equal(Complex{2.0, 5.0}));

  const Scalar two_zero(Complex{2.0, 0});
  ASSERT_TRUE(two_zero.equal(2.0));
  ASSERT_TRUE(two_zero.equal(2));

  // Double receiver.
  const Scalar two_d(2.0);
  ASSERT_TRUE(two_d.equal(Complex{2.0, 0.0}));
  ASSERT_FALSE(two_d.equal(Complex{2.0, 4.0}));
  ASSERT_FALSE(two_d.equal(3.0));
  ASSERT_TRUE(two_d.equal(2));

  // Integer receiver.
  const Scalar two_i(2);
  ASSERT_TRUE(two_i.equal(Complex{2.0, 0}));
  ASSERT_TRUE(two_i.equal(2));
  ASSERT_TRUE(two_i.equal(2.0));
}
179 
// Checks Scalar's stream insertion output:
//   - integers and doubles print as plain numbers;
//   - bools print as "true"/"false";
//   - complex values print as "(real,imag)";
//   - a SymInt-backed Scalar prints like its integer value.
TEST(TestScalar, TestFormatting) {
  const auto render = [](const Scalar& value) {
    std::ostringstream out;
    out << value;
    return out.str();
  };

  // Each case pairs the expected text with the Scalar that must produce it.
  struct Case {
    const char* expected;
    Scalar value;
  };
  const Case cases[] = {
      {"3", Scalar(3)},
      {"3.1", Scalar(3.1)},
      {"true", Scalar(true)},
      {"false", Scalar(false)},
      {"(2,3.1)", Scalar(c10::complex<double>(2.0, 3.1))},
      {"(2,3.1)", Scalar(c10::complex<float>(2.0, 3.1))},
      {"4", Scalar(Scalar(4).toSymInt())},
  };

  for (const Case& c : cases) {
    ASSERT_EQ(c.expected, render(c.value));
  }
}
194