1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <array>
7
8 #include <xnnpack.h>
9 #include <xnnpack/node-type.h>
10
11 #include "subgraph-tester.h"
12 #include <gtest/gtest.h>
13
14 namespace xnnpack {
15
TEST(SUBGRAPH_FP16,value_both_external_output_and_input)16 TEST(SUBGRAPH_FP16, value_both_external_output_and_input) {
17 auto tester = SubgraphTester(4);
18 std::array<size_t, 4> pre_paddings = {0,1,0,0};
19 std::array<size_t, 4> post_paddings = {0,1,0,0};
20 // external input[0]
21 // /
22 // [constant pad]
23 // /
24 // external dynamic[1]
25 // output[2] /
26 // \ /
27 // [add]
28 // |
29 // external
30 // output[3]
31 tester
32 .AddInputTensorF32({1, 2, 2, 3}, 0)
33 .AddDynamicTensorF32({1, 1, 1, 3}, 1)
34 .AddOutputTensorF32({1, 4, 2, 3}, 2)
35 .AddOutputTensorF32({1, 4, 2, 3}, 3)
36 .AddConstantPad(pre_paddings.data(), post_paddings.data(), 0.0f, 0, 2)
37 .AddAddition(2, 1, 3)
38 .Optimize()
39 .RewriteForFp16();
40
41 // After rewriting for FP16, the graph should look like this, with * indicating new operators and values created:
42 //
43 // external input[0]
44 // |
45 // [convert]*
46 // |
47 // input[4]*
48 // /
49 // [constant pad]
50 // /
51 // fp16 value[5]*
52 // | \
53 // [convert]* \
54 // | \
55 // external \ dynamic[1] converted in-place
56 // output[2] \ /
57 // \ /
58 // [add]
59 // |
60 // fp16 value[6]*
61 // |
62 // [convert]*
63 // |
64 // external
65 // output[3]
66
67 // We should have 3 convert nodes, one for external input, 2 for external
68 // outputs, so 5 in total, including the pad and add in the original graph.
69 ASSERT_EQ(tester.NumNodes(), 5);
70
71 const xnn_node* output_convert_node = tester.Node(4);
72 ASSERT_EQ(output_convert_node->type, xnn_node_type_convert);
73 ASSERT_EQ(output_convert_node->compute_type, xnn_compute_type_fp16_to_fp32);
74
75 // Check that Addition node refers to the FP16 value before conversion.
76 const xnn_node* addition_node = tester.Node(3);
77 ASSERT_EQ(addition_node->type, xnn_node_type_add2);
78 ASSERT_EQ(addition_node->inputs[0], 5);
79 ASSERT_EQ(addition_node->inputs[1], 1);
80 ASSERT_EQ(tester.Value(5)->datatype, xnn_datatype_fp16);
81 ASSERT_EQ(tester.Value(1)->datatype, xnn_datatype_fp16);
82
83 ASSERT_EQ(tester.Node(2)->type, xnn_node_type_convert);
84 ASSERT_EQ(tester.Node(2)->compute_type, xnn_compute_type_fp16_to_fp32);
85 ASSERT_EQ(tester.Node(1)->type, xnn_node_type_static_constant_pad);
86 ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert);
87 ASSERT_EQ(tester.Node(0)->compute_type, xnn_compute_type_fp32_to_fp16);
88 }
89
90 } // namespace xnnpack
91