xref: /aosp_15_r20/external/XNNPACK/test/workspace.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #include <algorithm>
7*4bdc9457SAndroid Build Coastguard Worker #include <array>
8*4bdc9457SAndroid Build Coastguard Worker #include <cstddef>
9*4bdc9457SAndroid Build Coastguard Worker #include <cstdint>
10*4bdc9457SAndroid Build Coastguard Worker #include <limits>
11*4bdc9457SAndroid Build Coastguard Worker #include <memory>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
14*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/subgraph.h>
15*4bdc9457SAndroid Build Coastguard Worker 
16*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
17*4bdc9457SAndroid Build Coastguard Worker 
18*4bdc9457SAndroid Build Coastguard Worker namespace {
DefineGraphWithoutInternalTensors(xnn_subgraph_t * subgraph,std::array<size_t,4> dims)19*4bdc9457SAndroid Build Coastguard Worker void DefineGraphWithoutInternalTensors(xnn_subgraph_t* subgraph, std::array<size_t, 4> dims)
20*4bdc9457SAndroid Build Coastguard Worker {
21*4bdc9457SAndroid Build Coastguard Worker   xnn_create_subgraph(/*external_value_ids=*/0, /*flags=*/0, subgraph);
22*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_id = XNN_INVALID_VALUE_ID;
23*4bdc9457SAndroid Build Coastguard Worker   xnn_define_tensor_value(
24*4bdc9457SAndroid Build Coastguard Worker     *subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, XNN_INVALID_VALUE_ID,
25*4bdc9457SAndroid Build Coastguard Worker     XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id);
26*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_id, XNN_INVALID_VALUE_ID);
27*4bdc9457SAndroid Build Coastguard Worker 
28*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_VALUE_ID;
29*4bdc9457SAndroid Build Coastguard Worker   xnn_define_tensor_value(
30*4bdc9457SAndroid Build Coastguard Worker     *subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, XNN_INVALID_VALUE_ID,
31*4bdc9457SAndroid Build Coastguard Worker     XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);
32*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_VALUE_ID);
33*4bdc9457SAndroid Build Coastguard Worker 
34*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_define_abs(*subgraph, input_id, output_id, /*flags=*/0));
35*4bdc9457SAndroid Build Coastguard Worker }
36*4bdc9457SAndroid Build Coastguard Worker 
37*4bdc9457SAndroid Build Coastguard Worker // Helper function to create a subgraph with 1 input, 1 output, and 1 intermediate tensor.
38*4bdc9457SAndroid Build Coastguard Worker // input -> (abs) -> intermediate -> (hard swish) -> output
39*4bdc9457SAndroid Build Coastguard Worker // The size of the tensors are all the same, specified by `dims`.
DefineGraph(xnn_subgraph_t * subgraph,std::array<size_t,4> dims)40*4bdc9457SAndroid Build Coastguard Worker void DefineGraph(xnn_subgraph_t* subgraph, std::array<size_t, 4> dims)
41*4bdc9457SAndroid Build Coastguard Worker {
42*4bdc9457SAndroid Build Coastguard Worker   xnn_create_subgraph(/*external_value_ids=*/0, /*flags=*/0, subgraph);
43*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_id = XNN_INVALID_VALUE_ID;
44*4bdc9457SAndroid Build Coastguard Worker   xnn_define_tensor_value(
45*4bdc9457SAndroid Build Coastguard Worker     *subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, XNN_INVALID_VALUE_ID,
46*4bdc9457SAndroid Build Coastguard Worker     XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id);
47*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_id, XNN_INVALID_VALUE_ID);
48*4bdc9457SAndroid Build Coastguard Worker 
49*4bdc9457SAndroid Build Coastguard Worker   uint32_t intermediate_id = XNN_INVALID_VALUE_ID;
50*4bdc9457SAndroid Build Coastguard Worker   xnn_define_tensor_value(
51*4bdc9457SAndroid Build Coastguard Worker     *subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0,
52*4bdc9457SAndroid Build Coastguard Worker     &intermediate_id);
53*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(intermediate_id, XNN_INVALID_VALUE_ID);
54*4bdc9457SAndroid Build Coastguard Worker 
55*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_VALUE_ID;
56*4bdc9457SAndroid Build Coastguard Worker   xnn_define_tensor_value(
57*4bdc9457SAndroid Build Coastguard Worker     *subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, XNN_INVALID_VALUE_ID,
58*4bdc9457SAndroid Build Coastguard Worker     XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);
59*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_VALUE_ID);
60*4bdc9457SAndroid Build Coastguard Worker 
61*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_define_abs(*subgraph, input_id, intermediate_id, /*flags=*/0));
62*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_define_hardswish(*subgraph, intermediate_id, output_id, /*flags=*/0));
63*4bdc9457SAndroid Build Coastguard Worker }
64*4bdc9457SAndroid Build Coastguard Worker 
DefineGraphWithStaticData(xnn_subgraph_t * subgraph,std::array<size_t,4> dims,const std::vector<float> * static_value)65*4bdc9457SAndroid Build Coastguard Worker void DefineGraphWithStaticData(xnn_subgraph_t* subgraph, std::array<size_t, 4> dims, const std::vector<float>* static_value)
66*4bdc9457SAndroid Build Coastguard Worker {
67*4bdc9457SAndroid Build Coastguard Worker   xnn_create_subgraph(/*external_value_ids=*/0, /*flags=*/0, subgraph);
68*4bdc9457SAndroid Build Coastguard Worker   uint32_t input_id = XNN_INVALID_VALUE_ID;
69*4bdc9457SAndroid Build Coastguard Worker   xnn_define_tensor_value(
70*4bdc9457SAndroid Build Coastguard Worker     *subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, XNN_INVALID_VALUE_ID,
71*4bdc9457SAndroid Build Coastguard Worker     XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id);
72*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(input_id, XNN_INVALID_VALUE_ID);
73*4bdc9457SAndroid Build Coastguard Worker 
74*4bdc9457SAndroid Build Coastguard Worker   uint32_t static_value_id = XNN_INVALID_VALUE_ID;
75*4bdc9457SAndroid Build Coastguard Worker   xnn_define_tensor_value(
76*4bdc9457SAndroid Build Coastguard Worker     *subgraph, xnn_datatype_fp32, dims.size(), dims.data(), static_value->data(), XNN_INVALID_VALUE_ID, /*flags=*/0,
77*4bdc9457SAndroid Build Coastguard Worker     &static_value_id);
78*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(static_value_id, XNN_INVALID_VALUE_ID);
79*4bdc9457SAndroid Build Coastguard Worker 
80*4bdc9457SAndroid Build Coastguard Worker   uint32_t output_id = XNN_INVALID_VALUE_ID;
81*4bdc9457SAndroid Build Coastguard Worker   xnn_define_tensor_value(
82*4bdc9457SAndroid Build Coastguard Worker     *subgraph, xnn_datatype_fp32, dims.size(), dims.data(), nullptr, XNN_INVALID_VALUE_ID,
83*4bdc9457SAndroid Build Coastguard Worker     XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id);
84*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(output_id, XNN_INVALID_VALUE_ID);
85*4bdc9457SAndroid Build Coastguard Worker 
86*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success,
87*4bdc9457SAndroid Build Coastguard Worker             xnn_define_add2(*subgraph, -std::numeric_limits<float>::infinity(),
88*4bdc9457SAndroid Build Coastguard Worker                             std::numeric_limits<float>::infinity(), input_id,
89*4bdc9457SAndroid Build Coastguard Worker                             static_value_id, output_id, /*flags=*/0));
90*4bdc9457SAndroid Build Coastguard Worker }
91*4bdc9457SAndroid Build Coastguard Worker 
BlobInWorkspace(xnn_blob * blob,xnn_workspace_t workspace)92*4bdc9457SAndroid Build Coastguard Worker testing::AssertionResult BlobInWorkspace(xnn_blob* blob, xnn_workspace_t workspace) {
93*4bdc9457SAndroid Build Coastguard Worker   if ((blob->data >= workspace->data) &&
94*4bdc9457SAndroid Build Coastguard Worker          ((uintptr_t) blob->data + blob->size) <= ((uintptr_t) workspace->data + workspace->size)) {
95*4bdc9457SAndroid Build Coastguard Worker     return testing::AssertionSuccess();
96*4bdc9457SAndroid Build Coastguard Worker   } else {
97*4bdc9457SAndroid Build Coastguard Worker     return testing::AssertionFailure()
98*4bdc9457SAndroid Build Coastguard Worker         << "blob at " << blob->data << " of size " << blob->size
99*4bdc9457SAndroid Build Coastguard Worker         << "is outside of workspace at " << workspace->data << " of size " << workspace->size;
100*4bdc9457SAndroid Build Coastguard Worker   }
101*4bdc9457SAndroid Build Coastguard Worker }
102*4bdc9457SAndroid Build Coastguard Worker 
Contains(std::vector<xnn_runtime_t> workspace_users,xnn_runtime_t runtime)103*4bdc9457SAndroid Build Coastguard Worker testing::AssertionResult Contains(std::vector<xnn_runtime_t> workspace_users, xnn_runtime_t runtime) {
104*4bdc9457SAndroid Build Coastguard Worker   if (std::find(workspace_users.begin(), workspace_users.end(), runtime) != workspace_users.end()) {
105*4bdc9457SAndroid Build Coastguard Worker     return testing::AssertionSuccess();
106*4bdc9457SAndroid Build Coastguard Worker   } else {
107*4bdc9457SAndroid Build Coastguard Worker     return testing::AssertionFailure() << "runtime " << runtime << " not found in list of workspace users";
108*4bdc9457SAndroid Build Coastguard Worker   }
109*4bdc9457SAndroid Build Coastguard Worker }
110*4bdc9457SAndroid Build Coastguard Worker 
workspace_user_to_list(xnn_workspace_t workspace)111*4bdc9457SAndroid Build Coastguard Worker std::vector<xnn_runtime_t> workspace_user_to_list(xnn_workspace_t workspace)
112*4bdc9457SAndroid Build Coastguard Worker {
113*4bdc9457SAndroid Build Coastguard Worker   std::vector<xnn_runtime_t> users;
114*4bdc9457SAndroid Build Coastguard Worker   for (xnn_runtime_t rt = workspace->first_user; rt != NULL; rt = rt->next_workspace_user) {
115*4bdc9457SAndroid Build Coastguard Worker     users.push_back(rt);
116*4bdc9457SAndroid Build Coastguard Worker   }
117*4bdc9457SAndroid Build Coastguard Worker   return users;
118*4bdc9457SAndroid Build Coastguard Worker }
119*4bdc9457SAndroid Build Coastguard Worker }  // namespace
120*4bdc9457SAndroid Build Coastguard Worker 
TEST(WORKSPACE,static_data_not_moved_does_not_segv)121*4bdc9457SAndroid Build Coastguard Worker TEST(WORKSPACE, static_data_not_moved_does_not_segv)
122*4bdc9457SAndroid Build Coastguard Worker {
123*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> dims = {2, 20, 20, 3};
124*4bdc9457SAndroid Build Coastguard Worker   size_t num_elements = dims[0] * dims[1] * dims[2] * dims[3];
125*4bdc9457SAndroid Build Coastguard Worker 
126*4bdc9457SAndroid Build Coastguard Worker   xnn_initialize(/*allocator=*/nullptr);
127*4bdc9457SAndroid Build Coastguard Worker   xnn_workspace_t workspace = nullptr;
128*4bdc9457SAndroid Build Coastguard Worker   xnn_create_workspace(&workspace);
129*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_workspace, decltype(&xnn_release_workspace)> auto_workspace(workspace, xnn_release_workspace);
130*4bdc9457SAndroid Build Coastguard Worker 
131*4bdc9457SAndroid Build Coastguard Worker   // Create a graph that with static data.
132*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph1 = nullptr;
133*4bdc9457SAndroid Build Coastguard Worker   std::vector<float> static_data = std::vector<float>(num_elements, 1.0f);
134*4bdc9457SAndroid Build Coastguard Worker   DefineGraphWithStaticData(&subgraph1, dims, &static_data);
135*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph1(subgraph1, xnn_delete_subgraph);
136*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime1 = nullptr;
137*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1));
138*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime1(runtime1, xnn_delete_runtime);
139*4bdc9457SAndroid Build Coastguard Worker 
140*4bdc9457SAndroid Build Coastguard Worker   // The workspace remains at size 0, without any memory allocated, since we don't have any internal tensors.
141*4bdc9457SAndroid Build Coastguard Worker   size_t old_workspace_size = workspace->size;
142*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(old_workspace_size, 0);
143*4bdc9457SAndroid Build Coastguard Worker   void* old_runtime_workspace = runtime1->workspace->data;
144*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(old_runtime_workspace, nullptr);
145*4bdc9457SAndroid Build Coastguard Worker 
146*4bdc9457SAndroid Build Coastguard Worker   // Then create a graph that has internal tensors, we will need to resize the workspace.
147*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph2 = nullptr;
148*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph2, dims);
149*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph2(subgraph2, xnn_delete_subgraph);
150*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime2 = nullptr;
151*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2));
152*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime2(runtime2, xnn_delete_runtime);
153*4bdc9457SAndroid Build Coastguard Worker 
154*4bdc9457SAndroid Build Coastguard Worker   // Check that the workspace grew.
155*4bdc9457SAndroid Build Coastguard Worker   ASSERT_GE(workspace->size, num_elements * sizeof(float));
156*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(runtime2->workspace->data, nullptr);
157*4bdc9457SAndroid Build Coastguard Worker 
158*4bdc9457SAndroid Build Coastguard Worker   // Try to access all the blobs and ensure that we don't segfault.
159*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < runtime1->num_blobs; i++) {
160*4bdc9457SAndroid Build Coastguard Worker     xnn_blob* blob = &runtime1->blobs[i];
161*4bdc9457SAndroid Build Coastguard Worker     if (blob->allocation_type == xnn_allocation_type_external) {
162*4bdc9457SAndroid Build Coastguard Worker       continue;
163*4bdc9457SAndroid Build Coastguard Worker     }
164*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GT(blob->size, 0);
165*4bdc9457SAndroid Build Coastguard Worker     char access = *((char *)blob->data);
166*4bdc9457SAndroid Build Coastguard Worker     (void) access;
167*4bdc9457SAndroid Build Coastguard Worker   }
168*4bdc9457SAndroid Build Coastguard Worker 
169*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < runtime2->num_blobs; i++) {
170*4bdc9457SAndroid Build Coastguard Worker     xnn_blob* blob = &runtime2->blobs[i];
171*4bdc9457SAndroid Build Coastguard Worker     if (blob->allocation_type == xnn_allocation_type_external) {
172*4bdc9457SAndroid Build Coastguard Worker       continue;
173*4bdc9457SAndroid Build Coastguard Worker     }
174*4bdc9457SAndroid Build Coastguard Worker     ASSERT_GT(blob->size, 0);
175*4bdc9457SAndroid Build Coastguard Worker     char access = *((char *)blob->data);
176*4bdc9457SAndroid Build Coastguard Worker     (void) access;
177*4bdc9457SAndroid Build Coastguard Worker   }
178*4bdc9457SAndroid Build Coastguard Worker }
179*4bdc9457SAndroid Build Coastguard Worker 
TEST(WORKSPACE,workspace_no_growth)180*4bdc9457SAndroid Build Coastguard Worker TEST(WORKSPACE, workspace_no_growth)
181*4bdc9457SAndroid Build Coastguard Worker {
182*4bdc9457SAndroid Build Coastguard Worker   xnn_initialize(/*allocator=*/nullptr);
183*4bdc9457SAndroid Build Coastguard Worker   xnn_workspace_t workspace = nullptr;
184*4bdc9457SAndroid Build Coastguard Worker   xnn_create_workspace(&workspace);
185*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_workspace, decltype(&xnn_release_workspace)> auto_workspace(workspace, xnn_release_workspace);
186*4bdc9457SAndroid Build Coastguard Worker 
187*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> dims = {2, 20, 20, 3};
188*4bdc9457SAndroid Build Coastguard Worker 
189*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph1 = nullptr;
190*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph1, dims);
191*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph1(subgraph1, xnn_delete_subgraph);
192*4bdc9457SAndroid Build Coastguard Worker 
193*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime1 = nullptr;
194*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1));
195*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime1(runtime1, xnn_delete_runtime);
196*4bdc9457SAndroid Build Coastguard Worker 
197*4bdc9457SAndroid Build Coastguard Worker   size_t old_workspace_size = workspace->size;
198*4bdc9457SAndroid Build Coastguard Worker   ASSERT_GE(old_workspace_size, 0);
199*4bdc9457SAndroid Build Coastguard Worker   void* old_runtime_workspace = runtime1->workspace->data;
200*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(old_runtime_workspace, nullptr);
201*4bdc9457SAndroid Build Coastguard Worker 
202*4bdc9457SAndroid Build Coastguard Worker   // Create the same graph again with a different runtime that shares the workspace.
203*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph2 = nullptr;
204*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph2, dims);
205*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph2(subgraph2, xnn_delete_subgraph);
206*4bdc9457SAndroid Build Coastguard Worker 
207*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime2 = nullptr;
208*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2));
209*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime2(runtime2, xnn_delete_runtime);
210*4bdc9457SAndroid Build Coastguard Worker 
211*4bdc9457SAndroid Build Coastguard Worker   // Check that the workspace did not grow.
212*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->size, old_workspace_size);
213*4bdc9457SAndroid Build Coastguard Worker   // Check that runtime 2 uses the same workspace.
214*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime2->workspace->data, old_runtime_workspace);
215*4bdc9457SAndroid Build Coastguard Worker 
216*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime1->num_blobs, runtime2->num_blobs);
217*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < runtime1->num_blobs; i++) {
218*4bdc9457SAndroid Build Coastguard Worker     xnn_blob* blob1 = &runtime1->blobs[i];
219*4bdc9457SAndroid Build Coastguard Worker     if (blob1->allocation_type != xnn_allocation_type_workspace) {
220*4bdc9457SAndroid Build Coastguard Worker       continue;
221*4bdc9457SAndroid Build Coastguard Worker     }
222*4bdc9457SAndroid Build Coastguard Worker     ASSERT_TRUE(BlobInWorkspace(blob1, runtime1->workspace));
223*4bdc9457SAndroid Build Coastguard Worker     xnn_blob* blob2 = &runtime2->blobs[i];
224*4bdc9457SAndroid Build Coastguard Worker     ASSERT_TRUE(BlobInWorkspace(blob2, runtime2->workspace));
225*4bdc9457SAndroid Build Coastguard Worker   }
226*4bdc9457SAndroid Build Coastguard Worker 
227*4bdc9457SAndroid Build Coastguard Worker   std::vector<xnn_runtime_t> workspace_users = workspace_user_to_list(workspace);
228*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace_users.size(), 2);
229*4bdc9457SAndroid Build Coastguard Worker   ASSERT_TRUE(Contains(workspace_users, runtime1));
230*4bdc9457SAndroid Build Coastguard Worker   ASSERT_TRUE(Contains(workspace_users, runtime2));
231*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 3);
232*4bdc9457SAndroid Build Coastguard Worker }
233*4bdc9457SAndroid Build Coastguard Worker 
TEST(WORKSPACE,workspace_grow)234*4bdc9457SAndroid Build Coastguard Worker TEST(WORKSPACE, workspace_grow)
235*4bdc9457SAndroid Build Coastguard Worker {
236*4bdc9457SAndroid Build Coastguard Worker   xnn_initialize(/*allocator=*/nullptr);
237*4bdc9457SAndroid Build Coastguard Worker   xnn_workspace_t workspace = nullptr;
238*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_workspace(&workspace));
239*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_workspace, decltype(&xnn_release_workspace)> auto_workspace(workspace, xnn_release_workspace);
240*4bdc9457SAndroid Build Coastguard Worker 
241*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> dims1 = {2, 20, 20, 3};
242*4bdc9457SAndroid Build Coastguard Worker 
243*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph1 = nullptr;
244*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph1, dims1);
245*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph1(subgraph1, xnn_delete_subgraph);
246*4bdc9457SAndroid Build Coastguard Worker 
247*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime1 = nullptr;
248*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1));
249*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime1(runtime1, xnn_delete_runtime);
250*4bdc9457SAndroid Build Coastguard Worker 
251*4bdc9457SAndroid Build Coastguard Worker   size_t old_workspace_size = workspace->size;
252*4bdc9457SAndroid Build Coastguard Worker   ASSERT_GE(old_workspace_size, 0);
253*4bdc9457SAndroid Build Coastguard Worker   void* old_runtime_workspace = runtime1->workspace->data;
254*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(old_runtime_workspace, nullptr);
255*4bdc9457SAndroid Build Coastguard Worker 
256*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> dims2 = dims1;
257*4bdc9457SAndroid Build Coastguard Worker   // Create the same graph but with larger tensors, this will require a larger workspace.
258*4bdc9457SAndroid Build Coastguard Worker   std::transform(dims2.begin(), dims2.end(), dims2.begin(), [](size_t i) { return i * 2; });
259*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph2 = nullptr;
260*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph2, dims2);
261*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph2(subgraph2, xnn_delete_subgraph);
262*4bdc9457SAndroid Build Coastguard Worker 
263*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime2 = nullptr;
264*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2));
265*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime2(runtime2, xnn_delete_runtime);
266*4bdc9457SAndroid Build Coastguard Worker 
267*4bdc9457SAndroid Build Coastguard Worker   // Check that the workspace grew.
268*4bdc9457SAndroid Build Coastguard Worker   ASSERT_GE(workspace->size, old_workspace_size);
269*4bdc9457SAndroid Build Coastguard Worker   // Check that runtime 2 uses the same workspace.
270*4bdc9457SAndroid Build Coastguard Worker   ASSERT_NE(runtime2->workspace->data, old_runtime_workspace);
271*4bdc9457SAndroid Build Coastguard Worker   // Check that runtime1's workspace has been updated as well.
272*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime1->workspace->data, runtime2->workspace->data);
273*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime1->workspace->size, runtime2->workspace->size);
274*4bdc9457SAndroid Build Coastguard Worker 
275*4bdc9457SAndroid Build Coastguard Worker   // Check that both runtime's blob pointers are within range.
276*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < runtime1->num_blobs; i++) {
277*4bdc9457SAndroid Build Coastguard Worker     xnn_blob* blob = &runtime1->blobs[i];
278*4bdc9457SAndroid Build Coastguard Worker     if (blob->allocation_type != xnn_allocation_type_workspace) {
279*4bdc9457SAndroid Build Coastguard Worker       continue;
280*4bdc9457SAndroid Build Coastguard Worker     }
281*4bdc9457SAndroid Build Coastguard Worker     ASSERT_TRUE(BlobInWorkspace(blob, runtime1->workspace));
282*4bdc9457SAndroid Build Coastguard Worker   }
283*4bdc9457SAndroid Build Coastguard Worker   for (size_t i = 0; i < runtime2->num_blobs; i++) {
284*4bdc9457SAndroid Build Coastguard Worker     xnn_blob* blob = &runtime2->blobs[i];
285*4bdc9457SAndroid Build Coastguard Worker     if (blob->allocation_type != xnn_allocation_type_workspace) {
286*4bdc9457SAndroid Build Coastguard Worker       continue;
287*4bdc9457SAndroid Build Coastguard Worker     }
288*4bdc9457SAndroid Build Coastguard Worker     ASSERT_TRUE(BlobInWorkspace(blob, runtime2->workspace));
289*4bdc9457SAndroid Build Coastguard Worker   }
290*4bdc9457SAndroid Build Coastguard Worker 
291*4bdc9457SAndroid Build Coastguard Worker   std::vector<xnn_runtime_t> workspace_users = workspace_user_to_list(workspace);
292*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace_users.size(), 2);
293*4bdc9457SAndroid Build Coastguard Worker   ASSERT_TRUE(Contains(workspace_users, runtime1));
294*4bdc9457SAndroid Build Coastguard Worker   ASSERT_TRUE(Contains(workspace_users, runtime2));
295*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 3);
296*4bdc9457SAndroid Build Coastguard Worker }
297*4bdc9457SAndroid Build Coastguard Worker 
TEST(WORKSPACE,workspace_runtime_delete_head_runtime_first)298*4bdc9457SAndroid Build Coastguard Worker TEST(WORKSPACE, workspace_runtime_delete_head_runtime_first)
299*4bdc9457SAndroid Build Coastguard Worker {
300*4bdc9457SAndroid Build Coastguard Worker   xnn_initialize(/*allocator=*/nullptr);
301*4bdc9457SAndroid Build Coastguard Worker   xnn_workspace_t workspace = nullptr;
302*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_workspace(&workspace));
303*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_workspace, decltype(&xnn_release_workspace)> auto_workspace(workspace, xnn_release_workspace);
304*4bdc9457SAndroid Build Coastguard Worker 
305*4bdc9457SAndroid Build Coastguard Worker   const std::array<size_t, 4> dims = {2, 20, 20, 3};
306*4bdc9457SAndroid Build Coastguard Worker 
307*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph1 = nullptr;
308*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph1, dims);
309*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph1(subgraph1, xnn_delete_subgraph);
310*4bdc9457SAndroid Build Coastguard Worker 
311*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime1 = nullptr;
312*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1));
313*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime1(runtime1, xnn_delete_runtime);
314*4bdc9457SAndroid Build Coastguard Worker 
315*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph2 = nullptr;
316*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph2, dims);
317*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph2(subgraph2, xnn_delete_subgraph);
318*4bdc9457SAndroid Build Coastguard Worker 
319*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime2 = nullptr;
320*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2));
321*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime2(runtime2, xnn_delete_runtime);
322*4bdc9457SAndroid Build Coastguard Worker 
323*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->first_user, runtime2);
324*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime2->next_workspace_user, runtime1);
325*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime1->next_workspace_user, nullptr);
326*4bdc9457SAndroid Build Coastguard Worker 
327*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 3);
328*4bdc9457SAndroid Build Coastguard Worker   xnn_delete_runtime(auto_runtime2.release());
329*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->first_user, runtime1);
330*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime1->next_workspace_user, nullptr);
331*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 2);
332*4bdc9457SAndroid Build Coastguard Worker 
333*4bdc9457SAndroid Build Coastguard Worker   xnn_delete_runtime(auto_runtime1.release());
334*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->first_user, nullptr);
335*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 1);
336*4bdc9457SAndroid Build Coastguard Worker }
337*4bdc9457SAndroid Build Coastguard Worker 
TEST(WORKSPACE,workspace_runtime_delete_tail_runtime_first)338*4bdc9457SAndroid Build Coastguard Worker TEST(WORKSPACE, workspace_runtime_delete_tail_runtime_first)
339*4bdc9457SAndroid Build Coastguard Worker {
340*4bdc9457SAndroid Build Coastguard Worker   xnn_initialize(/*allocator=*/nullptr);
341*4bdc9457SAndroid Build Coastguard Worker   xnn_workspace_t workspace = nullptr;
342*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_workspace(&workspace));
343*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_workspace, decltype(&xnn_release_workspace)> auto_workspace(workspace, xnn_release_workspace);
344*4bdc9457SAndroid Build Coastguard Worker 
345*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> dims = {2, 20, 20, 3};
346*4bdc9457SAndroid Build Coastguard Worker 
347*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph1 = nullptr;
348*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph1, dims);
349*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph1(subgraph1, xnn_delete_subgraph);
350*4bdc9457SAndroid Build Coastguard Worker 
351*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime1 = nullptr;
352*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1));
353*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime1(runtime1, xnn_delete_runtime);
354*4bdc9457SAndroid Build Coastguard Worker 
355*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph2 = nullptr;
356*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph2, dims);
357*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph2(subgraph2, xnn_delete_subgraph);
358*4bdc9457SAndroid Build Coastguard Worker 
359*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime2 = nullptr;
360*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2));
361*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime2(runtime2, xnn_delete_runtime);
362*4bdc9457SAndroid Build Coastguard Worker 
363*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->first_user, runtime2);
364*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime2->next_workspace_user, runtime1);
365*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime1->next_workspace_user, nullptr);
366*4bdc9457SAndroid Build Coastguard Worker 
367*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 3);
368*4bdc9457SAndroid Build Coastguard Worker   xnn_delete_runtime(auto_runtime1.release());
369*4bdc9457SAndroid Build Coastguard Worker 
370*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->first_user, runtime2);
371*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime2->next_workspace_user, nullptr);
372*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 2);
373*4bdc9457SAndroid Build Coastguard Worker 
374*4bdc9457SAndroid Build Coastguard Worker   xnn_delete_runtime(auto_runtime2.release());
375*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->first_user, nullptr);
376*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 1);
377*4bdc9457SAndroid Build Coastguard Worker }
378*4bdc9457SAndroid Build Coastguard Worker 
TEST(WORKSPACE,workspace_runtime_delete_middle_runtime_first)379*4bdc9457SAndroid Build Coastguard Worker TEST(WORKSPACE, workspace_runtime_delete_middle_runtime_first)
380*4bdc9457SAndroid Build Coastguard Worker {
381*4bdc9457SAndroid Build Coastguard Worker   xnn_initialize(/*allocator=*/nullptr);
382*4bdc9457SAndroid Build Coastguard Worker   xnn_workspace_t workspace = nullptr;
383*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_workspace(&workspace));
384*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_workspace, decltype(&xnn_release_workspace)> auto_workspace(workspace, xnn_release_workspace);
385*4bdc9457SAndroid Build Coastguard Worker 
386*4bdc9457SAndroid Build Coastguard Worker   std::array<size_t, 4> dims = {2, 20, 20, 3};
387*4bdc9457SAndroid Build Coastguard Worker 
388*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph1 = nullptr;
389*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph1, dims);
390*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph1(subgraph1, xnn_delete_subgraph);
391*4bdc9457SAndroid Build Coastguard Worker 
392*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime1 = nullptr;
393*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph1, nullptr, workspace, nullptr, 0, &runtime1));
394*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime1(runtime1, xnn_delete_runtime);
395*4bdc9457SAndroid Build Coastguard Worker 
396*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph2 = nullptr;
397*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph2, dims);
398*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph2(subgraph2, xnn_delete_subgraph);
399*4bdc9457SAndroid Build Coastguard Worker 
400*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime2 = nullptr;
401*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph2, nullptr, workspace, nullptr, 0, &runtime2));
402*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime2(runtime2, xnn_delete_runtime);
403*4bdc9457SAndroid Build Coastguard Worker 
404*4bdc9457SAndroid Build Coastguard Worker   xnn_subgraph_t subgraph3 = nullptr;
405*4bdc9457SAndroid Build Coastguard Worker   DefineGraph(&subgraph3, dims);
406*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph3(subgraph3, xnn_delete_subgraph);
407*4bdc9457SAndroid Build Coastguard Worker 
408*4bdc9457SAndroid Build Coastguard Worker   xnn_runtime_t runtime3 = nullptr;
409*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph3, nullptr, workspace, nullptr, 0, &runtime3));
410*4bdc9457SAndroid Build Coastguard Worker   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime3(runtime3, xnn_delete_runtime);
411*4bdc9457SAndroid Build Coastguard Worker 
412*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->first_user, runtime3);
413*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime3->next_workspace_user, runtime2);
414*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime2->next_workspace_user, runtime1);
415*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime1->next_workspace_user, nullptr);
416*4bdc9457SAndroid Build Coastguard Worker 
417*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 4);
418*4bdc9457SAndroid Build Coastguard Worker   xnn_delete_runtime(auto_runtime2.release());
419*4bdc9457SAndroid Build Coastguard Worker 
420*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->first_user, runtime3);
421*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime3->next_workspace_user, runtime1);
422*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime1->next_workspace_user, nullptr);
423*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 3);
424*4bdc9457SAndroid Build Coastguard Worker 
425*4bdc9457SAndroid Build Coastguard Worker   xnn_delete_runtime(auto_runtime3.release());
426*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->first_user, runtime1);
427*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(runtime1->next_workspace_user, nullptr);
428*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 2);
429*4bdc9457SAndroid Build Coastguard Worker 
430*4bdc9457SAndroid Build Coastguard Worker   xnn_delete_runtime(auto_runtime1.release());
431*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->first_user, nullptr);
432*4bdc9457SAndroid Build Coastguard Worker   ASSERT_EQ(workspace->ref_count, 1);
433*4bdc9457SAndroid Build Coastguard Worker }
434*4bdc9457SAndroid Build Coastguard Worker 
// A runtime built from a graph defined by DefineGraphWithoutInternalTensors is
// expected to leave the shared workspace unallocated (size 0, null data) while
// still being registered as a user of that workspace.
TEST(WORKSPACE, zero_sized_workspace_for_graph_without_internal_tensors)
{
  xnn_initialize(/*allocator=*/nullptr);
  xnn_workspace_t workspace = nullptr;
  // Check the creation status, as the other workspace tests do, so a failed
  // creation is reported here instead of surfacing later with a null workspace.
  ASSERT_EQ(xnn_status_success, xnn_create_workspace(&workspace));
  std::unique_ptr<xnn_workspace, decltype(&xnn_release_workspace)> auto_workspace(workspace, xnn_release_workspace);

  std::array<size_t, 4> dims = {2, 20, 20, 3};

  xnn_subgraph_t subgraph = nullptr;
  DefineGraphWithoutInternalTensors(&subgraph, dims);
  std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> auto_subgraph(subgraph, xnn_delete_subgraph);

  xnn_runtime_t runtime = nullptr;
  ASSERT_EQ(xnn_status_success, xnn_create_runtime_v4(subgraph, nullptr, workspace, nullptr, 0, &runtime));
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(runtime, xnn_delete_runtime);

  // Nothing should have been allocated for the workspace.
  ASSERT_EQ(0, workspace->size);
  ASSERT_EQ(nullptr, workspace->data);
  // The runtime must still be the workspace's one and only user.
  // NOTE(review): workspace_user_to_list is defined outside this chunk; it
  // presumably collects the workspace's user chain into a vector - confirm.
  ASSERT_EQ(std::vector<xnn_runtime_t>({runtime}), workspace_user_to_list(workspace));
  // Presumably one reference held by this test plus one by the runtime.
  ASSERT_EQ(workspace->ref_count, 2);
}
457