xref: /aosp_15_r20/external/armnn/src/backends/cl/test/ClBackendTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <backendsCommon/TensorHandleFactoryRegistry.hpp>
7 #include <cl/ClBackend.hpp>
8 #include <cl/ClTensorHandleFactory.hpp>
9 #include <cl/ClImportTensorHandleFactory.hpp>
10 #include <cl/test/ClContextControlFixture.hpp>
11 
12 #include <doctest/doctest.h>
13 
14 using namespace armnn;
15 
16 TEST_SUITE("ClBackendTests")
17 {
18 TEST_CASE("ClRegisterTensorHandleFactoriesMatchingImportFactoryId")
19 {
20     auto clBackend = std::make_unique<ClBackend>();
21     TensorHandleFactoryRegistry registry;
22     clBackend->RegisterTensorHandleFactories(registry);
23 
24     // When calling RegisterTensorHandleFactories, CopyAndImportFactoryPair is registered
25     // Get ClImportTensorHandleFactory id as the matching import factory id
26     CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
27            ClImportTensorHandleFactory::GetIdStatic()));
28 }
29 
30 TEST_CASE("ClRegisterTensorHandleFactoriesWithMemorySourceFlagsMatchingImportFactoryId")
31 {
32     auto clBackend = std::make_unique<ClBackend>();
33     TensorHandleFactoryRegistry registry;
34     clBackend->RegisterTensorHandleFactories(registry,
35                                              static_cast<MemorySourceFlags>(MemorySource::Malloc),
36                                              static_cast<MemorySourceFlags>(MemorySource::Malloc));
37 
38     // When calling RegisterTensorHandleFactories with MemorySourceFlags, CopyAndImportFactoryPair is registered
39     // Get ClImportTensorHandleFactory id as the matching import factory id
40     CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
41            ClImportTensorHandleFactory::GetIdStatic()));
42 }
43 
44 TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryMatchingImportFactoryId")
45 {
46     auto clBackend = std::make_unique<ClBackend>();
47     TensorHandleFactoryRegistry registry;
48     clBackend->CreateWorkloadFactory(registry);
49 
50     // When calling CreateWorkloadFactory, CopyAndImportFactoryPair is registered
51     // Get ClImportTensorHandleFactory id as the matching import factory id
52     CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
53            ClImportTensorHandleFactory::GetIdStatic()));
54 }
55 
56 TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWithOptionsMatchingImportFactoryId")
57 {
58     auto clBackend = std::make_unique<ClBackend>();
59     TensorHandleFactoryRegistry registry;
60     ModelOptions modelOptions;
61     clBackend->CreateWorkloadFactory(registry, modelOptions);
62 
63     // When calling CreateWorkloadFactory with ModelOptions, CopyAndImportFactoryPair is registered
64     // Get ClImportTensorHandleFactory id as the matching import factory id
65     CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
66            ClImportTensorHandleFactory::GetIdStatic()));
67 }
68 
69 TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWitMemoryFlagsMatchingImportFactoryId")
70 {
71     auto clBackend = std::make_unique<ClBackend>();
72     TensorHandleFactoryRegistry registry;
73     ModelOptions modelOptions;
74     clBackend->CreateWorkloadFactory(registry, modelOptions,
75                                      static_cast<MemorySourceFlags>(MemorySource::Malloc),
76                                      static_cast<MemorySourceFlags>(MemorySource::Malloc));
77 
78     // When calling CreateWorkloadFactory with ModelOptions and MemorySourceFlags,
79     // CopyAndImportFactoryPair is registered
80     // Get ClImportTensorHandleFactory id as the matching import factory id
81     CHECK((registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
82            ClImportTensorHandleFactory::GetIdStatic()));
83 }
84 }
85