xref: /aosp_15_r20/external/XNNPACK/test/subgraph-fp16.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #include <array>
7*4bdc9457SAndroid Build Coastguard Worker 
8*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
9*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/node-type.h>
10*4bdc9457SAndroid Build Coastguard Worker 
11*4bdc9457SAndroid Build Coastguard Worker #include "subgraph-tester.h"
12*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
13*4bdc9457SAndroid Build Coastguard Worker 
14*4bdc9457SAndroid Build Coastguard Worker namespace xnnpack {
15*4bdc9457SAndroid Build Coastguard Worker 
TEST(SUBGRAPH_FP16,value_both_external_output_and_input)16*4bdc9457SAndroid Build Coastguard Worker TEST(SUBGRAPH_FP16, value_both_external_output_and_input) {
17*4bdc9457SAndroid Build Coastguard Worker   auto tester = SubgraphTester(4);
18*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> pre_paddings = {0,1,0,0};
19*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> post_paddings = {0,1,0,0};
20*4bdc9457SAndroid Build Coastguard Worker   // external input[0]
21*4bdc9457SAndroid Build Coastguard Worker   //      /
22*4bdc9457SAndroid Build Coastguard Worker   // [constant pad]
23*4bdc9457SAndroid Build Coastguard Worker   //     /
24*4bdc9457SAndroid Build Coastguard Worker   //  external     dynamic[1]
25*4bdc9457SAndroid Build Coastguard Worker   //  output[2]     /
26*4bdc9457SAndroid Build Coastguard Worker   //           \   /
27*4bdc9457SAndroid Build Coastguard Worker   //           [add]
28*4bdc9457SAndroid Build Coastguard Worker   //             |
29*4bdc9457SAndroid Build Coastguard Worker   //         external
30*4bdc9457SAndroid Build Coastguard Worker   //         output[3]
31*4bdc9457SAndroid Build Coastguard Worker   tester
32*4bdc9457SAndroid Build Coastguard Worker       .AddInputTensorF32({1, 2, 2, 3}, 0)
33*4bdc9457SAndroid Build Coastguard Worker       .AddDynamicTensorF32({1, 1, 1, 3}, 1)
34*4bdc9457SAndroid Build Coastguard Worker       .AddOutputTensorF32({1, 4, 2, 3}, 2)
35*4bdc9457SAndroid Build Coastguard Worker       .AddOutputTensorF32({1, 4, 2, 3}, 3)
36*4bdc9457SAndroid Build Coastguard Worker       .AddConstantPad(pre_paddings.data(), post_paddings.data(), 0.0f, 0, 2)
37*4bdc9457SAndroid Build Coastguard Worker       .AddAddition(2, 1, 3)
38*4bdc9457SAndroid Build Coastguard Worker       .Optimize()
39*4bdc9457SAndroid Build Coastguard Worker       .RewriteForFp16();
40*4bdc9457SAndroid Build Coastguard Worker 
41*4bdc9457SAndroid Build Coastguard Worker   // After rewriting for FP16, the graph should look like this, with * indicating new operators and values created:
42*4bdc9457SAndroid Build Coastguard Worker   //
43*4bdc9457SAndroid Build Coastguard Worker   //   external input[0]
44*4bdc9457SAndroid Build Coastguard Worker   //        |
45*4bdc9457SAndroid Build Coastguard Worker   //    [convert]*
46*4bdc9457SAndroid Build Coastguard Worker   //        |
47*4bdc9457SAndroid Build Coastguard Worker   //     input[4]*
48*4bdc9457SAndroid Build Coastguard Worker   //       /
49*4bdc9457SAndroid Build Coastguard Worker   // [constant pad]
50*4bdc9457SAndroid Build Coastguard Worker   //     /
51*4bdc9457SAndroid Build Coastguard Worker   //   fp16 value[5]*
52*4bdc9457SAndroid Build Coastguard Worker   //    |       \
53*4bdc9457SAndroid Build Coastguard Worker   //  [convert]* \
54*4bdc9457SAndroid Build Coastguard Worker   //    |         \
55*4bdc9457SAndroid Build Coastguard Worker   //  external     \    dynamic[1] converted in-place
56*4bdc9457SAndroid Build Coastguard Worker   //  output[2]     \     /
57*4bdc9457SAndroid Build Coastguard Worker   //                 \   /
58*4bdc9457SAndroid Build Coastguard Worker   //                 [add]
59*4bdc9457SAndroid Build Coastguard Worker   //                   |
60*4bdc9457SAndroid Build Coastguard Worker   //                fp16 value[6]*
61*4bdc9457SAndroid Build Coastguard Worker   //                   |
62*4bdc9457SAndroid Build Coastguard Worker   //                [convert]*
63*4bdc9457SAndroid Build Coastguard Worker   //                   |
64*4bdc9457SAndroid Build Coastguard Worker   //               external
65*4bdc9457SAndroid Build Coastguard Worker   //               output[3]
66*4bdc9457SAndroid Build Coastguard Worker 
67*4bdc9457SAndroid Build Coastguard Worker   // We should have 3 convert nodes, one for external input, 2 for external
68*4bdc9457SAndroid Build Coastguard Worker   // outputs, so 5 in total, including the pad and add in the original graph.
69*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(tester.NumNodes(), 5);
70*4bdc9457SAndroid Build Coastguard Worker 
71*4bdc9457SAndroid Build Coastguard Worker   const xnn_node* output_convert_node = tester.Node(4);
72*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(output_convert_node->type, xnn_node_type_convert);
73*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(output_convert_node->compute_type, xnn_compute_type_fp16_to_fp32);
74*4bdc9457SAndroid Build Coastguard Worker 
75*4bdc9457SAndroid Build Coastguard Worker   // Check that Addition node refers to the FP16 value before conversion.
76*4bdc9457SAndroid Build Coastguard Worker   const xnn_node* addition_node = tester.Node(3);
77*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(addition_node->type, xnn_node_type_add2);
78*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(addition_node->inputs[0], 5);
79*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(addition_node->inputs[1], 1);
80*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(tester.Value(5)->datatype, xnn_datatype_fp16);
81*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(tester.Value(1)->datatype, xnn_datatype_fp16);
82*4bdc9457SAndroid Build Coastguard Worker 
83*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(tester.Node(2)->type, xnn_node_type_convert);
84*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(tester.Node(2)->compute_type, xnn_compute_type_fp16_to_fp32);
85*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(tester.Node(1)->type, xnn_node_type_static_constant_pad);
86*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(tester.Node(0)->type, xnn_node_type_convert);
87*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(tester.Node(0)->compute_type, xnn_compute_type_fp32_to_fp16);
88*4bdc9457SAndroid Build Coastguard Worker }
89*4bdc9457SAndroid Build Coastguard Worker 
90*4bdc9457SAndroid Build Coastguard Worker }  // namespace xnnpack
91