1 #include <gtest/gtest.h>
2 
3 #include <torch/torch.h>
4 
5 #include <test/cpp/api/support.h>
6 
7 using namespace torch::nn;
8 
9 struct TransformerTest : torch::test::SeedingFixture {};
10 
11 // a generic function that sets parameters to constant values so we get a
12 // fixed result for deterministic tests
13 template <typename Model>
14 void set_parameter_to_constants(
15     Model& model,
16     const torch::TensorOptions& tensor_options) {
17   torch::NoGradGuard guard;
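  // fill each parameter with cos(0), cos(1), ..., reshaped to the parameter's
  // own shape, so every run starts from identical, reproducible weights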
18   for (auto& p : model->parameters()) {
19     auto sz = p.view(-1).size(0);
20     p.copy_(torch::cos(torch::arange(0, sz, tensor_options).view(p.sizes())));
21   }
22 }
23 
24 // a generic function to provide a consistent encoder/decoder layer for all
25 // the transformer tests
26 template <typename T_LAYER, typename T_OPTIONS>
27 T_LAYER get_a_test_layer(
28     const torch::TensorOptions& tensor_options,
29     bool use_callable_activation) {
30   int64_t d_model = 4;
31   int64_t nhead = 2;
32   int64_t dim_feedforward = 16;
33   double dropout = 0.0;
34 
35   // the activation is ReLU by default here; it can be adjusted later
36   // depending on the usage
37   T_LAYER layer(T_OPTIONS(d_model, nhead)
38                     .dim_feedforward(dim_feedforward)
39                     .dropout(dropout));
40   if (tensor_options.device() == torch::kCUDA) {
41     layer->to(torch::kCUDA);
42   }
43   if (use_callable_activation) {
44     layer.get()->options.activation(
45         [&](const torch::Tensor& t) { return torch::nn::functional::relu(t); });
46   }
47 
48   // set constant weights of the model
49   set_parameter_to_constants<T_LAYER>(layer, tensor_options);
50 
51   return layer;
52 }
53 
54 void transformer_encoder_layer_test_helper(
55     bool is_cuda,
56     bool use_callable_activation) {
57   // this is a deterministic test for TransformerEncoderLayer
58   torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
59   torch::TensorOptions tensor_options =
60       torch::TensorOptions().dtype(torch::kFloat32).device(device);
61 
62   TransformerEncoderLayer model =
63       get_a_test_layer<TransformerEncoderLayer, TransformerEncoderLayerOptions>(
64           tensor_options, use_callable_activation);
65 
66   // relu test case 1
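  // inputs follow the (sequence length, batch size, d_model) layout expected
  // by the transformer modules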
67   torch::Tensor encoder_input =
68       torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
69   torch::Tensor result = model(encoder_input).detach();
70   torch::Tensor ref_output = torch::tensor(
71       {{{2.258703, 0.127985, -0.697881, 0.170862}}}, tensor_options);
72   ASSERT_EQ(result.sizes(), ref_output.sizes());
73   ASSERT_TRUE(
74       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
75 
76   // all 0 values are NOT masked. This shouldn't mask anything
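  // comparing against 1 converts the float tensor into the boolean
  // key-padding mask the layer expects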
77   torch::Tensor mask = torch::tensor({{0}}, tensor_options) == 1;
78   result = model(
79                encoder_input,
80                /*src_mask=*/torch::Tensor{},
81                /*src_key_padding_mask=*/mask)
82                .detach();
83   ASSERT_EQ(result.sizes(), ref_output.sizes());
84   ASSERT_TRUE(
85       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
86 
87   // all 1 values are masked. Since there is only 1 input embedding this will
88   // result in NaN.
89   mask = torch::tensor({{1}}, tensor_options) == 1;
90   result = model(
91                encoder_input,
92                /*src_mask=*/torch::Tensor{},
93                /*src_key_padding_mask=*/mask)
94                .detach();
95   ASSERT_TRUE(torch::isnan(result).all().item().to<bool>());
96 
97   // relu test case 2
98   encoder_input =
99       torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
100   result = model(encoder_input).detach();
101   ref_output = torch::tensor(
102       {{{2.272644, 0.119035, -0.691669, 0.153486}},
103        {{2.272644, 0.119035, -0.691669, 0.153486}}},
104       tensor_options);
105   ASSERT_EQ(result.sizes(), ref_output.sizes());
106   ASSERT_TRUE(
107       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
108 
109   // all 0 values are NOT masked
110   mask = torch::tensor({{0, 0}}, tensor_options) == 1;
111   result = model(
112                encoder_input,
113                /*src_mask=*/torch::Tensor{},
114                /*src_key_padding_mask=*/mask)
115                .detach();
116   ASSERT_EQ(result.sizes(), ref_output.sizes());
117   ASSERT_TRUE(
118       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
119 
120   // mask with 1 and 0
121   mask = torch::tensor({{1, 0}}, tensor_options) == 1;
122   result = model(
123                encoder_input,
124                /*src_mask=*/torch::Tensor{},
125                /*src_key_padding_mask=*/mask)
126                .detach();
127   ref_output = torch::tensor(
128       {{{2.301516, 0.092249, -0.679101, 0.103088}},
129        {{2.301516, 0.092249, -0.679101, 0.103088}}},
130       tensor_options);
131   ASSERT_EQ(result.sizes(), ref_output.sizes());
132   ASSERT_TRUE(
133       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
134 
135   // relu test case 3
136   encoder_input = torch::tensor(
137       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
138        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
139        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
140        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
141        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
142       tensor_options);
143   result = model(encoder_input).detach();
144   ref_output = torch::tensor(
145       {{{2.428589, 0.020835, -0.602055, -0.085249},
146         {2.427987, 0.021213, -0.602496, -0.084103}},
147        {{2.424689, 0.019155, -0.604793, -0.085672},
148         {2.413863, 0.022211, -0.612486, -0.072490}},
149        {{2.433774, 0.021598, -0.598343, -0.087548},
150         {2.425104, 0.019748, -0.604515, -0.084839}},
151        {{2.436185, 0.022682, -0.596625, -0.087261},
152         {2.433556, 0.021891, -0.598509, -0.086832}},
153        {{2.416246, 0.017512, -0.610712, -0.082961},
154         {2.422901, 0.024187, -0.606178, -0.074929}}},
155       tensor_options);
156   ASSERT_EQ(result.sizes(), ref_output.sizes());
157   ASSERT_TRUE(
158       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
159 
160   // all 0 values are NOT masked
161   mask = torch::zeros({2, 5}, tensor_options) == 1;
162   result = model(
163                encoder_input,
164                /*src_mask=*/torch::Tensor{},
165                /*src_key_padding_mask=*/mask)
166                .detach();
167   ASSERT_EQ(result.sizes(), ref_output.sizes());
168   ASSERT_TRUE(
169       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
170 
171   // mask with 0s and 1s
172   mask[0][1] = 1;
173   mask[1][3] = 1;
174   mask[1][4] = 1;
175   result = model(
176                encoder_input,
177                /*src_mask=*/torch::Tensor{},
178                /*src_key_padding_mask=*/mask)
179                .detach();
180   ref_output = torch::tensor(
181       {{{2.429026, 0.020793, -0.601741, -0.085642},
182         {2.428811, 0.021445, -0.601912, -0.084252}},
183        {{2.425009, 0.019155, -0.604566, -0.085899},
184         {2.415408, 0.02249, -0.611415, -0.073}},
185        {{2.434199, 0.021682, -0.598039, -0.087699},
186         {2.42598, 0.019941, -0.603896, -0.085091}},
187        {{2.436457, 0.022736, -0.59643, -0.08736},
188         {2.434021, 0.022093, -0.598179, -0.08679}},
189        {{2.416531, 0.017498, -0.610513, -0.083181},
190         {2.4242, 0.024653, -0.605266, -0.074959}}},
191       tensor_options);
192   ASSERT_EQ(result.sizes(), ref_output.sizes());
193   ASSERT_TRUE(
194       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
195 
196   // gelu test case 1
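  // updating options.activation on the existing module switches the activation
  // used by subsequent forward calls; the constant weights are left untouched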
197   model.get()->options.activation(torch::kGELU);
198   encoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
199   result = model(encoder_input).detach();
200   ref_output = torch::tensor(
201       {{{2.249815, 0.131006, -0.702199, 0.177868}}}, tensor_options);
202   ASSERT_EQ(result.sizes(), ref_output.sizes());
203   ASSERT_TRUE(
204       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
205 
206   // gelu test case 2
207   encoder_input = torch::tensor(
208       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
209        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
210        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
211        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
212        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
213       tensor_options);
214   result = model(encoder_input);
215   ref_output = torch::tensor(
216       {{{2.42163188, 0.03227153, -0.60714219, -0.05908082},
217         {2.42151276, 0.03302179, -0.60722523, -0.05762651}},
218        {{2.41926761, 0.02974034, -0.60879519, -0.0621269},
219         {2.41626395, 0.03539356, -0.61087842, -0.04978623}},
220        {{2.42382808, 0.03218872, -0.6055963, -0.06073591},
221         {2.41983477, 0.03085259, -0.60840145, -0.06046414}},
222        {{2.42500749, 0.03328855, -0.60476388, -0.0595334},
223         {2.4237977, 0.03290575, -0.60561789, -0.05940082}},
224        {{2.41383916, 0.02686345, -0.61256377, -0.06380707},
225         {2.42000277, 0.03800944, -0.60824798, -0.04754947}}},
226       tensor_options);
227   ASSERT_EQ(result.sizes(), ref_output.sizes());
228   ASSERT_TRUE(
229       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
230 }
231 
232 TEST_F(TransformerTest, TransformerEncoderLayer) {
233   transformer_encoder_layer_test_helper(
234       /*is_cuda=*/false, /*use_callable_activation=*/false);
235   transformer_encoder_layer_test_helper(
236       /*is_cuda=*/false, /*use_callable_activation=*/true);
237 }
238 
239 TEST_F(TransformerTest, TransformerEncoderLayer_CUDA) {
240   transformer_encoder_layer_test_helper(
241       /*is_cuda=*/true, /*use_callable_activation=*/false);
242   transformer_encoder_layer_test_helper(
243       /*is_cuda=*/true, /*use_callable_activation=*/true);
244 }
245 
246 void transformer_decoder_layer_test_helper(
247     bool is_cuda,
248     bool use_callable_activation) {
249   torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
250   torch::TensorOptions tensor_options =
251       torch::TensorOptions().dtype(torch::kFloat32).device(device);
252 
253   TransformerDecoderLayer model =
254       get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
255           tensor_options, use_callable_activation);
256 
257   // deterministic input
258   torch::Tensor decoder_input =
259       torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
260   torch::Tensor memory_input =
261       torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
262   torch::Tensor result = model(decoder_input, memory_input).detach();
263   torch::Tensor ref_output = torch::tensor(
264       {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options);
265   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
266   ASSERT_TRUE(torch::allclose(
267       result,
268       ref_output,
269       1e-7,
270       1e-5,
271       /*equal_nan=*/true));
272 
273   // deterministic input
274   decoder_input =
275       torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
276   memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
277   result = model(decoder_input, memory_input).detach();
278   ref_output = torch::tensor(
279       {{{2.422245, 0.051716, -0.606338, -0.024756}},
280        {{2.422245, 0.051716, -0.606338, -0.024756}}},
281       tensor_options);
282   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
283   ASSERT_TRUE(torch::allclose(
284       result,
285       ref_output,
286       1e-7,
287       1e-5,
288       /*equal_nan=*/true));
289 
290   // deterministic input
291   decoder_input =
292       torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
293   memory_input =
294       torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
295   result = model(decoder_input, memory_input).detach();
296   ref_output = torch::tensor(
297       {{{2.343536, 0.085561, -0.654954, 0.074991}},
298        {{2.343536, 0.085561, -0.654954, 0.074991}}},
299       tensor_options);
300   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
301   ASSERT_TRUE(torch::allclose(
302       result,
303       ref_output,
304       1e-7,
305       1e-5,
306       /*equal_nan=*/true));
307 
308   // deterministic input
309   decoder_input = torch::tensor(
310       {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
311        {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
312        {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
313       tensor_options);
314   memory_input = torch::tensor(
315       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
316        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
317        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
318        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
319        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
320       tensor_options);
321   result = model(decoder_input, memory_input).detach();
322   ref_output = torch::tensor(
323       {{{2.430065, 0.027862, -0.601136, -0.073096},
324         {2.431935, 0.028907, -0.599809, -0.072488}},
325        {{2.428457, 0.027053, -0.602275, -0.073462},
326         {2.431970, 0.029387, -0.599789, -0.071621}},
327        {{2.431934, 0.028196, -0.599802, -0.073809},
328         {2.432306, 0.028858, -0.599542, -0.072846}}},
329       tensor_options);
330   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
331   ASSERT_TRUE(torch::allclose(
332       result,
333       ref_output,
334       1e-7,
335       1e-5,
336       /*equal_nan=*/true));
337 
338   // key_padding_mask
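  // default-constructed (undefined) tensors mean "no mask" for the tgt and
  // memory attention masks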
339   torch::Tensor t_mask = {};
340   torch::Tensor m_mask = {};
341   torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1;
342   result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
343                .detach();
344   ref_output = torch::tensor(
345       {{{2.430065, 0.027862, -0.601136, -0.073096},
346         {2.431935, 0.028907, -0.599809, -0.072488}},
347        {{2.428457, 0.027053, -0.602275, -0.073462},
348         {2.431970, 0.029387, -0.599789, -0.071621}},
349        {{2.431934, 0.028196, -0.599802, -0.073809},
350         {2.432306, 0.028858, -0.599542, -0.072846}}},
351       tensor_options);
352   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
353   ASSERT_TRUE(torch::allclose(
354       result,
355       ref_output,
356       1e-7,
357       1e-5,
358       /*equal_nan=*/true));
359 
360   // key_padding_mask
361   key_padding_mask[0][2] = 1;
362   key_padding_mask[1][1] = 1;
363   key_padding_mask[1][2] = 1;
364   result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
365                .detach();
366   ref_output = torch::tensor(
367       {{{2.430025, 0.027643, -0.601164, -0.073476},
368         {2.4323, 0.029375, -0.599553, -0.071881}},
369        {{2.428523, 0.026838, -0.602226, -0.07391},
370         {2.432634, 0.029842, -0.599318, -0.071253}},
371        {{2.432278, 0.028152, -0.599555, -0.074139},
372         {2.432659, 0.029244, -0.599294, -0.072382}}},
373       tensor_options);
374   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
375   ASSERT_TRUE(torch::allclose(
376       result,
377       ref_output,
378       1e-7,
379       1e-5,
380       /*equal_nan=*/true));
381 
382   // memory_key_padding_mask
383   torch::Tensor t_key_padding_mask = {};
384   key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1;
385   result = model(
386                decoder_input,
387                memory_input,
388                t_mask,
389                m_mask,
390                t_key_padding_mask,
391                key_padding_mask)
392                .detach();
393   ref_output = torch::tensor(
394       {{{2.430065, 0.027862, -0.601136, -0.073096},
395         {2.431935, 0.028907, -0.599809, -0.072488}},
396        {{2.428457, 0.027053, -0.602275, -0.073462},
397         {2.431970, 0.029387, -0.599789, -0.071621}},
398        {{2.431934, 0.028196, -0.599802, -0.073809},
399         {2.432306, 0.028858, -0.599542, -0.072846}}},
400       tensor_options);
401   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
402   ASSERT_TRUE(torch::allclose(
403       result,
404       ref_output,
405       1e-7,
406       1e-5,
407       /*equal_nan=*/true));
408 
409   // memory_key_padding_mask
410   key_padding_mask[0][4] = 1;
411   key_padding_mask[1][3] = 1;
412   key_padding_mask[1][4] = 1;
413   result = model(
414                decoder_input,
415                memory_input,
416                t_mask,
417                m_mask,
418                t_key_padding_mask,
419                key_padding_mask)
420                .detach();
421   ref_output = torch::tensor(
422       {{{2.429757, 0.027358, -0.601351, -0.073816},
423         {2.432692, 0.028583, -0.599263, -0.073634}},
424        {{2.428247, 0.02662, -0.602419, -0.074123},
425         {2.432657, 0.029055, -0.599293, -0.072732}},
426        {{2.431515, 0.027687, -0.600096, -0.074459},
427         {2.433075, 0.028543, -0.598987, -0.073985}}},
428       tensor_options);
429   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
430   ASSERT_TRUE(torch::allclose(
431       result,
432       ref_output,
433       1e-7,
434       1e-5,
435       /*equal_nan=*/true));
436 }
437 
438 TEST_F(TransformerTest, TransformerDecoderLayer) {
439   transformer_decoder_layer_test_helper(
440       /*is_cuda=*/false, /*use_callable_activation=*/false);
441   transformer_decoder_layer_test_helper(
442       /*is_cuda=*/false, /*use_callable_activation=*/true);
443 }
444 
445 TEST_F(TransformerTest, TransformerDecoderLayer_CUDA) {
446   transformer_decoder_layer_test_helper(
447       /*is_cuda=*/true, /*use_callable_activation=*/false);
448   transformer_decoder_layer_test_helper(
449       /*is_cuda=*/true, /*use_callable_activation=*/true);
450 }
451 
452 void transformer_decoder_layer_test_helper_gelu(
453     bool is_cuda,
454     bool use_callable_activation) {
455   torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
456   torch::TensorOptions tensor_options =
457       torch::TensorOptions().dtype(torch::kFloat32).device(device);
458 
459   TransformerDecoderLayer model =
460       get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
461           tensor_options, use_callable_activation);
462   if (use_callable_activation) {
463     model.get()->options.activation(
464         [&](const torch::Tensor& t) { return torch::nn::functional::gelu(t); });
465   } else {
466     model.get()->options.activation(torch::kGELU);
467   }
468 
469   // deterministic input
470   torch::Tensor decoder_input =
471       torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
472   torch::Tensor memory_input =
473       torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
474   torch::Tensor result = model(decoder_input, memory_input).detach();
475   torch::Tensor ref_output = torch::tensor(
476       {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options);
477   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
478   ASSERT_TRUE(torch::allclose(
479       result,
480       ref_output,
481       1e-7,
482       1e-5,
483       /*equal_nan=*/true));
484 
485   // deterministic input
486   decoder_input =
487       torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
488   memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
489   result = model(decoder_input, memory_input).detach();
490   ref_output = torch::tensor(
491       {{{2.415448, 0.054389, -0.610932, -0.0156613}},
492        {{2.415448, 0.054389, -0.610932, -0.0156613}}},
493       tensor_options);
494   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
495   ASSERT_TRUE(torch::allclose(
496       result,
497       ref_output,
498       1e-7,
499       1e-5,
500       /*equal_nan=*/true));
501 
502   // deterministic input
503   decoder_input =
504       torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
505   memory_input =
506       torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
507   result = model(decoder_input, memory_input).detach();
508   ref_output = torch::tensor(
509       {{{2.338531, 0.087709, -0.65776, 0.080646}},
510        {{2.338531, 0.087709, -0.65776, 0.080646}}},
511       tensor_options);
512   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
513   ASSERT_TRUE(torch::allclose(
514       result,
515       ref_output,
516       1e-7,
517       1e-5,
518       /*equal_nan=*/true));
519 
520   // deterministic input
521   decoder_input = torch::tensor(
522       {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
523        {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
524        {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
525       tensor_options);
526   memory_input = torch::tensor(
527       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
528        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
529        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
530        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
531        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
532       tensor_options);
533   result = model(decoder_input, memory_input).detach();
534   ref_output = torch::tensor(
535       {{{2.42049104, 0.03443088, -0.60793706, -0.05436271},
536         {2.42210631, 0.03546578, -0.60679895, -0.05357488}},
537        {{2.41907674, 0.0336104, -0.60892977, -0.05490462},
538         {2.42216881, 0.03586554, -0.6067524, -0.05289126}},
539        {{2.42205716, 0.03488046, -0.60683681, -0.05460596},
540         {2.42240309, 0.0354595, -0.60659063, -0.05378816}}},
541       tensor_options);
542   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
543   ASSERT_TRUE(torch::allclose(
544       result,
545       ref_output,
546       1e-7,
547       1e-5,
548       /*equal_nan=*/true));
549 }
550 
551 TEST_F(TransformerTest, TransformerDecoderLayer_gelu) {
552   transformer_decoder_layer_test_helper_gelu(
553       /*is_cuda=*/false, /*use_callable_activation=*/false);
554   transformer_decoder_layer_test_helper_gelu(
555       /*is_cuda=*/false, /*use_callable_activation=*/true);
556 }
557 
558 TEST_F(TransformerTest, TransformerDecoderLayer_gelu_CUDA) {
559   transformer_decoder_layer_test_helper_gelu(
560       /*is_cuda=*/true, /*use_callable_activation=*/false);
561   transformer_decoder_layer_test_helper_gelu(
562       /*is_cuda=*/true, /*use_callable_activation=*/true);
563 }
564 
565 void transformer_encoder_test_helper(
566     bool is_cuda,
567     bool use_callable_activation) {
568   // this is a deterministic test for TransformerEncoder
569   torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
570   torch::TensorOptions tensor_options =
571       torch::TensorOptions().dtype(torch::kFloat32).device(device);
572 
573   TransformerEncoderLayer encoder_layer =
574       get_a_test_layer<TransformerEncoderLayer, TransformerEncoderLayerOptions>(
575           tensor_options, use_callable_activation);
576 
577   TransformerEncoder model(TransformerEncoderOptions(encoder_layer, 1));
578   if (is_cuda) {
579     model->to(torch::kCUDA);
580   }
581 
582   torch::Tensor encoder_input = torch::tensor(
583       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
584        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
585        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
586        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
587        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
588       tensor_options);
589   torch::Tensor result = model(encoder_input).detach();
590   torch::Tensor ref_output = torch::tensor(
591       {{{2.428589, 0.020835, -0.602055, -0.085249},
592         {2.427987, 0.021213, -0.602496, -0.084103}},
593        {{2.424689, 0.019155, -0.604793, -0.085672},
594         {2.413863, 0.022211, -0.612486, -0.072490}},
595        {{2.433774, 0.021598, -0.598343, -0.087548},
596         {2.425104, 0.019748, -0.604515, -0.084839}},
597        {{2.436185, 0.022682, -0.596625, -0.087261},
598         {2.433556, 0.021891, -0.598509, -0.086832}},
599        {{2.416246, 0.017512, -0.610712, -0.082961},
600         {2.422901, 0.024187, -0.606178, -0.074929}}},
601       tensor_options);
602   ASSERT_EQ(result.sizes(), ref_output.sizes());
603   ASSERT_TRUE(
604       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
605 
606   // all 0 values are NOT masked
607   torch::Tensor mask = torch::zeros({2, 5}, tensor_options) == 1;
608   result = model(
609                encoder_input,
610                /*src_mask=*/torch::Tensor{},
611                /*src_key_padding_mask=*/mask)
612                .detach();
613   ASSERT_EQ(result.sizes(), ref_output.sizes());
614   ASSERT_TRUE(
615       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
616 
617   // mask with 0s and 1s
618   mask[0][1] = 1;
619   mask[1][3] = 1;
620   mask[1][4] = 1;
621   result = model(
622                encoder_input,
623                /*src_mask=*/torch::Tensor{},
624                /*src_key_padding_mask=*/mask)
625                .detach();
626   ref_output = torch::tensor(
627       {{{2.429026, 0.020793, -0.601741, -0.085642},
628         {2.428811, 0.021445, -0.601912, -0.084252}},
629        {{2.425009, 0.019155, -0.604566, -0.085899},
630         {2.415408, 0.02249, -0.611415, -0.073}},
631        {{2.434199, 0.021682, -0.598039, -0.087699},
632         {2.42598, 0.019941, -0.603896, -0.085091}},
633        {{2.436457, 0.022736, -0.59643, -0.08736},
634         {2.434021, 0.022093, -0.598179, -0.08679}},
635        {{2.416531, 0.017498, -0.610513, -0.083181},
636         {2.4242, 0.024653, -0.605266, -0.074959}}},
637       tensor_options);
638   ASSERT_EQ(result.sizes(), ref_output.sizes());
639   ASSERT_TRUE(
640       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
641 
642   // test case 2, multiple layers no norm
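  // TransformerEncoder clones the given layer num_layers times, so every
  // stacked layer here starts from the same constant weights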
643   model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 2));
644   if (is_cuda) {
645     model->to(torch::kCUDA);
646   }
647   result = model(
648                encoder_input,
649                /*src_mask=*/torch::Tensor{},
650                /*src_key_padding_mask=*/mask)
651                .detach();
652   ref_output = torch::tensor(
653       {{{2.419051, 0.017446, -0.608738, -0.085003},
654         {2.419102, 0.017452, -0.608703, -0.085026}},
655        {{2.419043, 0.017445, -0.608744, -0.084999},
656         {2.419052, 0.017446, -0.608738, -0.085004}},
657        {{2.419067, 0.017448, -0.608727, -0.085010},
658         {2.419098, 0.017452, -0.608706, -0.085024}},
659        {{2.419072, 0.017449, -0.608724, -0.085012},
660         {2.419119, 0.017455, -0.608691, -0.085034}},
661        {{2.419019, 0.017442, -0.608761, -0.084989},
662         {2.419075, 0.017449, -0.608722, -0.085014}}},
663       tensor_options);
664   ASSERT_EQ(result.sizes(), ref_output.sizes());
665   ASSERT_TRUE(
666       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
667 
668   model = TransformerEncoder(TransformerEncoderOptions(encoder_layer, 6));
669   if (is_cuda) {
670     model->to(torch::kCUDA);
671   }
672   result = model(
673                encoder_input,
674                /*src_mask=*/torch::Tensor{},
675                /*src_key_padding_mask=*/mask)
676                .detach();
677   ref_output = torch::tensor(
678       {{{2.419101, 0.017453, -0.608703, -0.085025},
679         {2.419101, 0.017453, -0.608704, -0.085025}},
680        {{2.419101, 0.017453, -0.608703, -0.085025},
681         {2.419101, 0.017453, -0.608704, -0.085025}},
682        {{2.419101, 0.017453, -0.608703, -0.085025},
683         {2.419101, 0.017453, -0.608704, -0.085025}},
684        {{2.419101, 0.017453, -0.608703, -0.085025},
685         {2.419101, 0.017453, -0.608704, -0.085025}},
686        {{2.419101, 0.017453, -0.608703, -0.085025},
687         {2.419101, 0.017453, -0.608704, -0.085025}}},
688       tensor_options);
689   ASSERT_EQ(result.sizes(), ref_output.sizes());
690   ASSERT_TRUE(
691       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
692 
693   // test case 3, multiple layers with norm
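  // the module passed via .norm() is applied once to the output of the final
  // encoder layer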
694   LayerNorm norm(LayerNormOptions({encoder_layer.get()->options.d_model()}));
695   model = TransformerEncoder(
696       TransformerEncoderOptions(encoder_layer, 2).norm(AnyModule(norm)));
697   if (is_cuda) {
698     model->to(torch::kCUDA);
699   }
700   result = model(
701                encoder_input,
702                /*src_mask=*/torch::Tensor{},
703                /*src_key_padding_mask=*/mask)
704                .detach();
705   ref_output = torch::tensor(
706       {{{1.695949, -0.357635, -0.893077, -0.445238},
707         {1.695955, -0.357639, -0.893050, -0.445266}},
708        {{1.695948, -0.357634, -0.893082, -0.445233},
709         {1.695950, -0.357635, -0.893077, -0.445238}},
710        {{1.695951, -0.357636, -0.893069, -0.445246},
711         {1.695955, -0.357639, -0.893052, -0.445264}},
712        {{1.695952, -0.357636, -0.893066, -0.445249},
713         {1.695957, -0.357641, -0.893041, -0.445276}},
714        {{1.695946, -0.357632, -0.893095, -0.445220},
715         {1.695952, -0.357637, -0.893065, -0.445251}}},
716       tensor_options);
717   ASSERT_EQ(result.sizes(), ref_output.sizes());
718   ASSERT_TRUE(
719       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
720 
721   model = TransformerEncoder(
722       TransformerEncoderOptions(encoder_layer, 6).norm(AnyModule(norm)));
723   if (is_cuda) {
724     model->to(torch::kCUDA);
725   }
726   result = model(
727                encoder_input,
728                /*src_mask=*/torch::Tensor{},
729                /*src_key_padding_mask=*/mask)
730                .detach();
731   ref_output = torch::tensor(
732       {{{1.695955, -0.357639, -0.893051, -0.445265},
733         {1.695955, -0.357639, -0.893051, -0.445265}},
734        {{1.695955, -0.357639, -0.893051, -0.445265},
735         {1.695955, -0.357639, -0.893051, -0.445265}},
736        {{1.695955, -0.357639, -0.893051, -0.445265},
737         {1.695955, -0.357639, -0.893051, -0.445265}},
738        {{1.695955, -0.357639, -0.893051, -0.445265},
739         {1.695955, -0.357639, -0.893051, -0.445265}},
740        {{1.695955, -0.357639, -0.893051, -0.445265},
741         {1.695955, -0.357639, -0.893051, -0.445265}}},
742       tensor_options);
743   ASSERT_EQ(result.sizes(), ref_output.sizes());
744   ASSERT_TRUE(
745       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
746 }
747 
748 TEST_F(TransformerTest, TransformerEncoder) {
749   transformer_encoder_test_helper(
750       /*is_cuda=*/false, /*use_callable_activation=*/false);
751   transformer_encoder_test_helper(
752       /*is_cuda=*/false, /*use_callable_activation=*/true);
753 }
754 
755 TEST_F(TransformerTest, TransformerEncoder_CUDA) {
756   transformer_encoder_test_helper(
757       /*is_cuda=*/true, /*use_callable_activation=*/false);
758   transformer_encoder_test_helper(
759       /*is_cuda=*/true, /*use_callable_activation=*/true);
760 }
761 
762 TEST_F(TransformerTest, PrettyPrintTransformerEncoderLayer) {
763   ASSERT_EQ(
764       c10::str(TransformerEncoderLayer(4, 2)),
765       "torch::nn::TransformerEncoderLayerImpl(\n"
766       "  (self_attn): torch::nn::MultiheadAttention(\n"
767       "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
768       "  )\n"
769       "  (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
770       "  (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
771       "  (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
772       "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
773       "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
774       "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
775       "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
776       ")");
777 }
778 
779 TEST_F(TransformerTest, PrettyPrintTransformerEncoder) {
780   LayerNorm norm = LayerNorm(LayerNormOptions({4}));
781   TransformerEncoderOptions options(
782       TransformerEncoderOptions(TransformerEncoderLayerOptions(4, 2), 2)
783           .norm(AnyModule(norm)));
784   ASSERT_EQ(
785       c10::str(TransformerEncoder(options)),
786       "torch::nn::TransformerEncoderImpl(\n"
787       "  (layers): torch::nn::ModuleList(\n"
788       "    (0): torch::nn::TransformerEncoderLayerImpl(\n"
789       "      (self_attn): torch::nn::MultiheadAttention(\n"
790       "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
791       "      )\n"
792       "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
793       "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
794       "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
795       "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
796       "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
797       "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
798       "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
799       "    )\n"
800       "    (1): torch::nn::TransformerEncoderLayerImpl(\n"
801       "      (self_attn): torch::nn::MultiheadAttention(\n"
802       "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
803       "      )\n"
804       "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
805       "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
806       "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
807       "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
808       "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
809       "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
810       "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
811       "    )\n"
812       "  )\n"
813       "  (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
814       ")");
815 }
816 
817 TEST_F(TransformerTest, PrettyPrintTransformerDecoderLayer) {
818   ASSERT_EQ(
819       c10::str(TransformerDecoderLayer(4, 2)),
820       "torch::nn::TransformerDecoderLayerImpl(\n"
821       "  (self_attn): torch::nn::MultiheadAttention(\n"
822       "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
823       "  )\n"
824       "  (multihead_attn): torch::nn::MultiheadAttention(\n"
825       "    (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
826       "  )\n"
827       "  (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
828       "  (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
829       "  (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
830       "  (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
831       "  (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
832       "  (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
833       "  (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
834       "  (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
835       "  (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
836       ")");
837 }
838 
839 void transformer_decoder_test_helper(
840     bool is_cuda,
841     bool use_callable_activation) {
842   // this is a deterministic test for TransformerDecoder
843   torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
844   torch::TensorOptions tensor_options =
845       torch::TensorOptions().dtype(torch::kFloat32).device(device);
846 
847   TransformerDecoderLayer decoder_layer =
848       get_a_test_layer<TransformerDecoderLayer, TransformerDecoderLayerOptions>(
849           tensor_options, use_callable_activation);
850 
851   TransformerDecoder model(TransformerDecoderOptions(decoder_layer, 1));
852   if (is_cuda) {
853     model->to(torch::kCUDA);
854   }
855 
856   torch::Tensor decoder_input =
857       torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
858   torch::Tensor memory_input =
859       torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
860   torch::Tensor result = model(decoder_input, memory_input).detach();
861   torch::Tensor ref_output = torch::tensor(
862       {{{2.314351, 0.094805, -0.671322, 0.101977}}}, tensor_options);
863   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
864   ASSERT_TRUE(torch::allclose(
865       result,
866       ref_output,
867       1e-7,
868       1e-5,
869       /*equal_nan=*/true));
870 
871   // deterministic input
872   decoder_input =
873       torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
874   memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
875   result = model(decoder_input, memory_input).detach();
876   ref_output = torch::tensor(
877       {{{2.422245, 0.051716, -0.606338, -0.024756}},
878        {{2.422245, 0.051716, -0.606338, -0.024756}}},
879       tensor_options);
880   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
881   ASSERT_TRUE(torch::allclose(
882       result,
883       ref_output,
884       1e-7,
885       1e-5,
886       /*equal_nan=*/true));
887 
888   // deterministic input
889   decoder_input =
890       torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
891   memory_input =
892       torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
893   result = model(decoder_input, memory_input).detach();
894   ref_output = torch::tensor(
895       {{{2.343536, 0.085561, -0.654954, 0.074991}},
896        {{2.343536, 0.085561, -0.654954, 0.074991}}},
897       tensor_options);
898   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
899   ASSERT_TRUE(torch::allclose(
900       result,
901       ref_output,
902       1e-7,
903       1e-5,
904       /*equal_nan=*/true));
905 
906   // deterministic input
907   decoder_input = torch::tensor(
908       {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
909        {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
910        {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
911       tensor_options);
912   memory_input = torch::tensor(
913       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
914        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
915        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
916        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
917        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
918       tensor_options);
919   result = model(decoder_input, memory_input).detach();
920   ref_output = torch::tensor(
921       {{{2.430065, 0.027862, -0.601136, -0.073096},
922         {2.431935, 0.028907, -0.599809, -0.072488}},
923        {{2.428457, 0.027053, -0.602275, -0.073462},
924         {2.431970, 0.029387, -0.599789, -0.071621}},
925        {{2.431934, 0.028196, -0.599802, -0.073809},
926         {2.432306, 0.028858, -0.599542, -0.072846}}},
927       tensor_options);
928   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
929   ASSERT_TRUE(torch::allclose(
930       result,
931       ref_output,
932       1e-7,
933       1e-5,
934       /*equal_nan=*/true));
935 
936   // key_padding_mask
937   torch::Tensor t_mask = {};
938   torch::Tensor m_mask = {};
939   torch::Tensor key_padding_mask = torch::zeros({2, 3}, tensor_options) == 1;
940   result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
941                .detach();
942   ref_output = torch::tensor(
943       {{{2.430065, 0.027862, -0.601136, -0.073096},
944         {2.431935, 0.028907, -0.599809, -0.072488}},
945        {{2.428457, 0.027053, -0.602275, -0.073462},
946         {2.431970, 0.029387, -0.599789, -0.071621}},
947        {{2.431934, 0.028196, -0.599802, -0.073809},
948         {2.432306, 0.028858, -0.599542, -0.072846}}},
949       tensor_options);
950   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
951   ASSERT_TRUE(torch::allclose(
952       result,
953       ref_output,
954       1e-7,
955       1e-5,
956       /*equal_nan=*/true));
957 
958   // key_padding_mask
959   key_padding_mask[0][2] = 1;
960   key_padding_mask[1][1] = 1;
961   key_padding_mask[1][2] = 1;
962   result = model(decoder_input, memory_input, t_mask, m_mask, key_padding_mask)
963                .detach();
964   ref_output = torch::tensor(
965       {{{2.430025, 0.027643, -0.601164, -0.073476},
966         {2.4323, 0.029375, -0.599553, -0.071881}},
967        {{2.428523, 0.026838, -0.602226, -0.07391},
968         {2.432634, 0.029842, -0.599318, -0.071253}},
969        {{2.432278, 0.028152, -0.599555, -0.074139},
970         {2.432659, 0.029244, -0.599294, -0.072382}}},
971       tensor_options);
972   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
973   ASSERT_TRUE(torch::allclose(
974       result,
975       ref_output,
976       1e-7,
977       1e-5,
978       /*equal_nan=*/true));
979 
980   // memory_key_padding_mask
981   torch::Tensor t_key_padding_mask = {};
982   key_padding_mask = torch::zeros({2, 5}, tensor_options) == 1;
983   result = model(
984                decoder_input,
985                memory_input,
986                t_mask,
987                m_mask,
988                t_key_padding_mask,
989                key_padding_mask)
990                .detach();
991   ref_output = torch::tensor(
992       {{{2.430065, 0.027862, -0.601136, -0.073096},
993         {2.431935, 0.028907, -0.599809, -0.072488}},
994        {{2.428457, 0.027053, -0.602275, -0.073462},
995         {2.431970, 0.029387, -0.599789, -0.071621}},
996        {{2.431934, 0.028196, -0.599802, -0.073809},
997         {2.432306, 0.028858, -0.599542, -0.072846}}},
998       tensor_options);
999   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1000   ASSERT_TRUE(torch::allclose(
1001       result,
1002       ref_output,
1003       1e-7,
1004       1e-5,
1005       /*equal_nan=*/true));
1006 
1007   // memory_key_padding_mask
1008   key_padding_mask[0][4] = 1;
1009   key_padding_mask[1][3] = 1;
1010   key_padding_mask[1][4] = 1;
1011   result = model(
1012                decoder_input,
1013                memory_input,
1014                t_mask,
1015                m_mask,
1016                t_key_padding_mask,
1017                key_padding_mask)
1018                .detach();
1019   ref_output = torch::tensor(
1020       {{{2.429757, 0.027358, -0.601351, -0.073816},
1021         {2.432692, 0.028583, -0.599263, -0.073634}},
1022        {{2.428247, 0.02662, -0.602419, -0.074123},
1023         {2.432657, 0.029055, -0.599293, -0.072732}},
1024        {{2.431515, 0.027687, -0.600096, -0.074459},
1025         {2.433075, 0.028543, -0.598987, -0.073985}}},
1026       tensor_options);
1027   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1028   ASSERT_TRUE(torch::allclose(
1029       result,
1030       ref_output,
1031       1e-7,
1032       1e-5,
1033       /*equal_nan=*/true));
1034 
1035   // multiple layers no norm
1036   model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 2));
1037   if (is_cuda) {
1038     model->to(torch::kCUDA);
1039   }
1040 
1041   decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1042   memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1043   result = model(decoder_input, memory_input).detach();
1044   ref_output = torch::tensor(
1045       {{{2.31316, 0.0950293, -0.671995, 0.102802}}}, tensor_options);
1046   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1047   ASSERT_TRUE(torch::allclose(
1048       result,
1049       ref_output,
1050       1e-7,
1051       1e-5,
1052       /*equal_nan=*/true));
1053 
1054   // multiple layers no norm
1055   model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6));
1056   if (is_cuda) {
1057     model->to(torch::kCUDA);
1058   }
1059   // deterministic input
1060   decoder_input = torch::tensor(
1061       {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1062        {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1063        {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1064       tensor_options);
1065   memory_input = torch::tensor(
1066       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1067        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1068        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1069        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1070        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1071       tensor_options);
1072   result = model(decoder_input, memory_input).detach();
1073   ref_output = torch::tensor(
1074       {{{2.42794, 0.026164, -0.60263, -0.0747591},
1075         {2.43113, 0.0279516, -0.600376, -0.0736896}},
1076        {{2.42794, 0.026164, -0.60263, -0.0747591},
1077         {2.43113, 0.0279516, -0.600376, -0.0736896}},
1078        {{2.42794, 0.026164, -0.60263, -0.0747591},
1079         {2.43113, 0.0279516, -0.600376, -0.0736896}}},
1080       tensor_options);
1081   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1082   ASSERT_TRUE(torch::allclose(
1083       result,
1084       ref_output,
1085       1e-7,
1086       1e-5,
1087       /*equal_nan=*/true));
1088 
1089   // multiple layers with norm
1090   LayerNorm norm(LayerNormOptions({decoder_layer.get()->options.d_model()}));
1091   model = TransformerDecoder(
1092       TransformerDecoderOptions(decoder_layer, 2).norm(AnyModule(norm)));
1093   if (is_cuda) {
1094     model->to(torch::kCUDA);
1095   }
1096 
1097   decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1098   memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1099   result = model(decoder_input, memory_input).detach();
1100   ref_output = torch::tensor(
1101       {{{1.66166, -0.326986, -1.01466, -0.320017}}}, tensor_options);
1102   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1103   ASSERT_TRUE(torch::allclose(
1104       result,
1105       ref_output,
1106       1e-7,
1107       1e-5,
1108       /*equal_nan=*/true));
1109 
1110   // multiple layers with norm
1111   model = TransformerDecoder(
1112       TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm)));
1113   if (is_cuda) {
1114     model->to(torch::kCUDA);
1115   }
1116   // deterministic input
1117   decoder_input = torch::tensor(
1118       {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1119        {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1120        {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1121       tensor_options);
1122   memory_input = torch::tensor(
1123       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1124        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1125        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1126        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1127        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1128       tensor_options);
1129   result = model(decoder_input, memory_input).detach();
1130   ref_output = torch::tensor(
1131       {{{1.69559, -0.357291, -0.894741, -0.443553},
1132         {1.69571, -0.357363, -0.894154, -0.444196}},
1133        {{1.69559, -0.357291, -0.894741, -0.443553},
1134         {1.69571, -0.357363, -0.894154, -0.444196}},
1135        {{1.69559, -0.357291, -0.894741, -0.443553},
1136         {1.69571, -0.357363, -0.894154, -0.444196}}},
1137       tensor_options);
1138   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1139   ASSERT_TRUE(torch::allclose(
1140       result,
1141       ref_output,
1142       1e-7,
1143       1e-5,
1144       /*equal_nan=*/true));
1145 
1146   // gelu activation test cases
1147   decoder_layer.get()->options.activation(torch::kGELU);
1148   model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 1));
1149   if (is_cuda) {
1150     model->to(torch::kCUDA);
1151   }
1152 
1153   // deterministic input
1154   decoder_input = torch::tensor({{{20, 30, 40, 50}}}, tensor_options);
1155   memory_input = torch::tensor({{{60, 70, 80, 90}}}, tensor_options);
1156   result = model(decoder_input, memory_input).detach();
1157   ref_output = torch::tensor(
1158       {{{2.306435, 0.095946, -0.675796, 0.10687}}}, tensor_options);
1159   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1160   ASSERT_TRUE(torch::allclose(
1161       result,
1162       ref_output,
1163       1e-7,
1164       1e-5,
1165       /*equal_nan=*/true));
1166 
1167   // deterministic input
1168   decoder_input =
1169       torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
1170   memory_input = torch::tensor({{{1, 2, 3, 4}}}, tensor_options);
1171   result = model(decoder_input, memory_input).detach();
1172   ref_output = torch::tensor(
1173       {{{2.415448, 0.054389, -0.610932, -0.0156613}},
1174        {{2.415448, 0.054389, -0.610932, -0.0156613}}},
1175       tensor_options);
1176   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1177   ASSERT_TRUE(torch::allclose(
1178       result,
1179       ref_output,
1180       1e-7,
1181       1e-5,
1182       /*equal_nan=*/true));
1183 
1184   // deterministic input
1185   decoder_input =
1186       torch::tensor({{{1, 2, 3, 4}}, {{5, 6, 7, 8}}}, tensor_options);
1187   memory_input =
1188       torch::tensor({{{9, 10, 11, 12}}, {{11, 12, 13, 14}}}, tensor_options);
1189   result = model(decoder_input, memory_input).detach();
1190   ref_output = torch::tensor(
1191       {{{2.338531, 0.087709, -0.65776, 0.080646}},
1192        {{2.338531, 0.087709, -0.65776, 0.080646}}},
1193       tensor_options);
1194   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1195   ASSERT_TRUE(torch::allclose(
1196       result,
1197       ref_output,
1198       1e-7,
1199       1e-5,
1200       /*equal_nan=*/true));
1201 
1202   // deterministic input
1203   decoder_input = torch::tensor(
1204       {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1205        {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1206        {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1207       tensor_options);
1208   memory_input = torch::tensor(
1209       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1210        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1211        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1212        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1213        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1214       tensor_options);
1215   result = model(decoder_input, memory_input).detach();
1216   ref_output = torch::tensor(
1217       {{{2.42049104, 0.03443088, -0.60793706, -0.05436271},
1218         {2.42210631, 0.03546578, -0.60679895, -0.05357488}},
1219        {{2.41907674, 0.0336104, -0.60892977, -0.05490462},
1220         {2.42216881, 0.03586554, -0.6067524, -0.05289126}},
1221        {{2.42205716, 0.03488046, -0.60683681, -0.05460596},
1222         {2.42240309, 0.0354595, -0.60659063, -0.05378816}}},
1223       tensor_options);
1224   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1225   ASSERT_TRUE(torch::allclose(
1226       result,
1227       ref_output,
1228       1e-7,
1229       1e-5,
1230       /*equal_nan=*/true));
1231 
1232   // Multiple layers no norm
1233   model = TransformerDecoder(TransformerDecoderOptions(decoder_layer, 6));
1234   if (is_cuda) {
1235     model->to(torch::kCUDA);
1236   }
1237   decoder_input = torch::tensor(
1238       {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1239        {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1240        {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1241       tensor_options);
1242   memory_input = torch::tensor(
1243       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1244        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1245        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1246        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1247        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1248       tensor_options);
1249   result = model(decoder_input, memory_input).detach();
1250   ref_output = torch::tensor(
1251       {{{2.41859, 0.0328114, -0.609269, -0.0560386},
1252         {2.42138, 0.034598, -0.607316, -0.0546574}},
1253        {{2.41859, 0.0328114, -0.609269, -0.0560386},
1254         {2.42138, 0.034598, -0.607316, -0.0546574}},
1255        {{2.41859, 0.0328114, -0.609269, -0.0560386},
1256         {2.42138, 0.034598, -0.607316, -0.0546574}}},
1257       tensor_options);
1258   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1259   ASSERT_TRUE(torch::allclose(
1260       result,
1261       ref_output,
1262       1e-7,
1263       1e-5,
1264       /*equal_nan=*/true));
1265 
1266   // Multiple layers with norm
1267   norm = LayerNorm(LayerNormOptions({decoder_layer.get()->options.d_model()}));
1268   model = TransformerDecoder(
1269       TransformerDecoderOptions(decoder_layer, 6).norm(AnyModule(norm)));
1270   if (is_cuda) {
1271     model->to(torch::kCUDA);
1272   }
1273 
1274   decoder_input = torch::tensor(
1275       {{{0.4517, 0.6793, 0.5313, 0.0034}, {0.2678, 0.3677, 0.4459, 0.7166}},
1276        {{0.8100, 0.3716, 0.4096, 0.1976}, {0.6958, 0.8844, 0.6081, 0.8315}},
1277        {{0.0494, 0.9343, 0.5955, 0.3830}, {0.5404, 0.3464, 0.9378, 0.6200}}},
1278       tensor_options);
1279   memory_input = torch::tensor(
1280       {{{0.7462, 0.6653, 0.5679, 0.4891}, {0.5387, 0.1655, 0.3565, 0.0471}},
1281        {{0.8335, 0.2799, 0.5031, 0.2947}, {0.1402, 0.0318, 0.7636, 0.1346}},
1282        {{0.6333, 0.9344, 0.1376, 0.9938}, {0.8924, 0.2872, 0.6692, 0.2944}},
1283        {{0.9897, 0.6915, 0.3154, 0.1733}, {0.8645, 0.3513, 0.3064, 0.0767}},
1284        {{0.8117, 0.2366, 0.4838, 0.7881}, {0.3718, 0.4945, 0.9511, 0.0864}}},
1285       tensor_options);
1286   result = model(decoder_input, memory_input).detach();
1287   ref_output = torch::tensor(
1288       {{{1.69298, -0.355163, -0.906375, -0.431439},
1289         {1.69305, -0.355195, -0.906062, -0.431791}},
1290        {{1.69298, -0.355163, -0.906375, -0.431439},
1291         {1.69305, -0.355195, -0.906062, -0.431791}},
1292        {{1.69298, -0.355163, -0.906375, -0.431439},
1293         {1.69305, -0.355195, -0.906062, -0.431791}}},
1294       tensor_options);
1295   ASSERT_EQ(result.sizes().size(), ref_output.sizes().size());
1296   ASSERT_TRUE(torch::allclose(
1297       result,
1298       ref_output,
1299       1e-7,
1300       1e-5,
1301       /*equal_nan=*/true));
1302 }
1303 
1304 TEST_F(TransformerTest, TransformerDecoder) {
1305   transformer_decoder_test_helper(
1306       /*is_cuda=*/false, /*use_callable_activation=*/false);
1307   transformer_decoder_test_helper(
1308       /*is_cuda=*/false, /*use_callable_activation=*/true);
1309 }
1310 
1311 TEST_F(TransformerTest, TransformerDecoder_CUDA) {
1312   transformer_decoder_test_helper(
1313       /*is_cuda=*/true, /*use_callable_activation=*/false);
1314   transformer_decoder_test_helper(
1315       /*is_cuda=*/true, /*use_callable_activation=*/true);
1316 }
1317 
1318 TEST_F(TransformerTest, PrettyPrintTransformerDecoder) {
1319   LayerNorm norm = LayerNorm(LayerNormOptions({4}));
1320   TransformerDecoderOptions options(
1321       TransformerDecoderOptions(TransformerDecoderLayerOptions(4, 2), 2)
1322           .norm(AnyModule(norm)));
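  // dim_feedforward and dropout are left at their defaults (2048 and 0.1),
  // which is what the linear1/linear2 sizes and Dropout(p=0.1) entries in the
  // expected string below reflect.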
1323   ASSERT_EQ(
1324       c10::str(TransformerDecoder(options)),
1325       "torch::nn::TransformerDecoderImpl(\n"
1326       "  (layers): torch::nn::ModuleList(\n"
1327       "    (0): torch::nn::TransformerDecoderLayerImpl(\n"
1328       "      (self_attn): torch::nn::MultiheadAttention(\n"
1329       "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
1330       "      )\n"
1331       "      (multihead_attn): torch::nn::MultiheadAttention(\n"
1332       "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
1333       "      )\n"
1334       "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
1335       "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
1336       "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
1337       "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1338       "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1339       "      (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1340       "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
1341       "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
1342       "      (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
1343       "    )\n"
1344       "    (1): torch::nn::TransformerDecoderLayerImpl(\n"
1345       "      (self_attn): torch::nn::MultiheadAttention(\n"
1346       "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
1347       "      )\n"
1348       "      (multihead_attn): torch::nn::MultiheadAttention(\n"
1349       "        (out_proj): torch::nn::Linear(in_features=4, out_features=4, bias=true)\n"
1350       "      )\n"
1351       "      (linear1): torch::nn::Linear(in_features=4, out_features=2048, bias=true)\n"
1352       "      (dropout): torch::nn::Dropout(p=0.1, inplace=false)\n"
1353       "      (linear2): torch::nn::Linear(in_features=2048, out_features=4, bias=true)\n"
1354       "      (norm1): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1355       "      (norm2): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1356       "      (norm3): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1357       "      (dropout1): torch::nn::Dropout(p=0.1, inplace=false)\n"
1358       "      (dropout2): torch::nn::Dropout(p=0.1, inplace=false)\n"
1359       "      (dropout3): torch::nn::Dropout(p=0.1, inplace=false)\n"
1360       "    )\n"
1361       "  )\n"
1362       "  (norm): torch::nn::LayerNorm([4], eps=1e-05, elementwise_affine=true)\n"
1363       ")");
1364 }
1365 
1366 void transformer_test_helper(bool is_cuda, bool use_callable_activation) {
1367   // this is a deterministic test for Transformer
1368   torch::Device device = is_cuda ? torch::kCUDA : torch::kCPU;
1369   torch::TensorOptions tensor_options =
1370       torch::TensorOptions().dtype(torch::kFloat32).device(device);
1371 
1372   // transformer that builds its own encoder/decoder from the options
1373   auto options = TransformerOptions()
1374                      .d_model(4)
1375                      .nhead(2)
1376                      .num_encoder_layers(2)
1377                      .num_decoder_layers(1)
1378                      .dim_feedforward(16)
1379                      .dropout(0.0)
1380                      .activation(torch::kReLU);
1381   if (use_callable_activation) {
1382     options.activation(
1383         [&](const torch::Tensor& t) { return torch::nn::functional::relu(t); });
1384   }
1385   Transformer model(options);
1386 
1387   set_parameter_to_constants<Transformer>(model, tensor_options);
1388   if (tensor_options.device() == torch::kCUDA) {
1389     model->to(torch::kCUDA);
1390   }
1391 
1392   // transformer with customized encoder/decoder
1393   LayerNorm enorm(LayerNormOptions({4}));
1394   TransformerEncoder encoder(
1395       TransformerEncoderOptions(
1396           TransformerEncoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0),
1397           2)
1398           .norm(AnyModule(enorm)));
1399 
1400   LayerNorm dnorm(LayerNormOptions({4}));
1401   TransformerDecoder decoder(
1402       TransformerDecoderOptions(
1403           TransformerDecoderLayerOptions(4, 2).dim_feedforward(16).dropout(0.0),
1404           1)
1405           .norm(AnyModule(dnorm)));
1406 
1407   Transformer model_cus(TransformerOptions()
1408                             .d_model(4)
1409                             .nhead(2)
1410                             .custom_encoder(AnyModule(encoder))
1411                             .custom_decoder(AnyModule(decoder)));
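  // model_cus mirrors the option-built model above (2 encoder layers, 1
  // decoder layer, dim_feedforward=16, dropout=0, final LayerNorms), so once
  // both receive the same constant parameters they are expected to produce
  // identical outputs; the result.equal(result_cus) checks below rely on this.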
1412 
1413   set_parameter_to_constants<Transformer>(model_cus, tensor_options);
1414   if (tensor_options.device() == torch::kCUDA) {
1415     model_cus->to(torch::kCUDA);
1416   }
1417 
1418   // test cases
1419   torch::Tensor src = torch::tensor(
1420       {{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
1421        {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}},
1422        {{17.0, 18.0, 19.0, 20.0}, {21.0, 22.0, 23.0, 24.0}}},
1423       tensor_options);
1424 
1425   torch::Tensor tgt = torch::tensor(
1426       {{{1.0, 2.0, 3.0, 4.0}, {5.0, 6.0, 7.0, 8.0}},
1427        {{9.0, 10.0, 11.0, 12.0}, {13.0, 14.0, 15.0, 16.0}}},
1428       tensor_options);
1429 
1430   torch::Tensor ref_output = torch::tensor(
1431       {{{2.695875, 0.347114, -0.044355, -0.549541},
1432         {2.696091, 0.347015, -0.044770, -0.548522}},
1433        {{2.695875, 0.347114, -0.044355, -0.549541},
1434         {2.696091, 0.347015, -0.044770, -0.548522}}},
1435       tensor_options);
1436   torch::Tensor result = model(src, tgt);
1437   torch::Tensor result_cus = model_cus(src, tgt);
1438   ASSERT_EQ(result.sizes(), ref_output.sizes());
1439   ASSERT_TRUE(result.equal(result_cus));
1440   ASSERT_TRUE(
1441       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
1442 
1443   torch::Tensor src_mask =
1444       Transformer::Impl::generate_square_subsequent_mask(src.size(0))
1445           .to(tensor_options);
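  // generate_square_subsequent_mask should yield a (src_len x src_len) float
  // mask with 0 on and below the diagonal and -inf above it, so each position
  // may only attend to itself and earlier positions.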
1446   ref_output = torch::tensor(
1447       {{{2.695875, 0.347114, -0.044355, -0.549541},
1448         {2.696091, 0.347015, -0.044770, -0.548522}},
1449        {{2.695875, 0.347114, -0.044355, -0.549541},
1450         {2.696091, 0.347015, -0.044770, -0.548522}}},
1451       tensor_options);
1452   result = model(src, tgt, src_mask);
1453   result_cus = model_cus(src, tgt, src_mask);
1454   ASSERT_EQ(result.sizes(), ref_output.sizes());
1455   ASSERT_TRUE(result.equal(result_cus));
1456   ASSERT_TRUE(
1457       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
1458 
1459   torch::Tensor tgt_key_padding_mask =
1460       torch::zeros({tgt.size(1), tgt.size(0)}, tensor_options) == 1;
1461   tgt_key_padding_mask[0][0] = 1;
1462   tgt_key_padding_mask[1][1] = 1;
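  // Boolean key padding mask of shape (batch_size, tgt_len); true entries mark
  // positions that attention should ignore (here position 0 of sample 0 and
  // position 1 of sample 1).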
1463   ref_output = torch::tensor(
1464       {{{2.696114, 0.347004, -0.044813, -0.548417},
1465         {2.696091, 0.347015, -0.044770, -0.548522}},
1466        {{2.696114, 0.347004, -0.044813, -0.548417},
1467         {2.696091, 0.347015, -0.044770, -0.548522}}},
1468       tensor_options);
1469   result = model(
1470       src,
1471       tgt,
1472       src_mask,
1473       /*tgt_mask=*/torch::Tensor(),
1474       /*memory_mask=*/torch::Tensor(),
1475       /*src_key_padding_mask=*/torch::Tensor(),
1476       tgt_key_padding_mask);
1477   result_cus = model_cus(
1478       src,
1479       tgt,
1480       src_mask,
1481       /*tgt_mask=*/torch::Tensor(),
1482       /*memory_mask=*/torch::Tensor(),
1483       /*src_key_padding_mask=*/torch::Tensor(),
1484       tgt_key_padding_mask);
1485   ASSERT_EQ(result.sizes(), ref_output.sizes());
1486   ASSERT_TRUE(result.equal(result_cus));
1487   ASSERT_TRUE(
1488       torch::allclose(result, ref_output, 1e-7, 1e-5, /*equal_nan=*/true));
1489 }
1490 
1491 TEST_F(TransformerTest, Transformer) {
1492   transformer_test_helper(/*is_cuda=*/false, /*use_callable_activation=*/false);
1493   transformer_test_helper(/*is_cuda=*/false, /*use_callable_activation=*/true);
1494 }
1495 
1496 TEST_F(TransformerTest, Transformer_CUDA) {
1497   transformer_test_helper(/*is_cuda=*/true, /*use_callable_activation=*/false);
1498   transformer_test_helper(/*is_cuda=*/true, /*use_callable_activation=*/true);
1499 }
1500 
1501 TEST_F(TransformerTest, TransformerArgsCorrectness) {
1502   Transformer model(TransformerOptions()
1503                         .d_model(4)
1504                         .nhead(2)
1505                         .num_encoder_layers(2)
1506                         .num_decoder_layers(1)
1507                         .dim_feedforward(16)
1508                         .dropout(0.0)
1509                         .activation(torch::kReLU));
1510 
1511   torch::Tensor src = torch::randn({2, 3, 4});
1512   torch::Tensor tgt = torch::randn({3, 2, 4});
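  // Inputs are (seq_len, batch_size, features): src has batch size 3 while tgt
  // has batch size 2, so the batch-size check below should throw.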
1513 
1514   ASSERT_THROWS_WITH(
1515       model(src, tgt), "src and tgt should have equal batch size");
1516 
1517   tgt = torch::randn({2, 3, 3});
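  // Feature size 3 does not match d_model = 4.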
1518   ASSERT_THROWS_WITH(
1519       model(src, tgt), "src and tgt should have same feature size as d_model");
1520 
1521   src = torch::randn({2, 3});
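  // A 2-D src violates the 3-dimensional (seq_len, batch_size, features)
  // requirement.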
1522   ASSERT_THROWS_WITH(model(src, tgt), "src and tgt should have 3 dimensions");
1523 }
1524