// Source: external/armnn/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
// (AOSP aosp_15_r20, revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnnTestUtils/LayerTestResult.hpp>

#include <ResolveType.hpp>

#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/backends/IBackendInternal.hpp>

#include <cstddef>
#include <vector>

14 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>, std::size_t NumDims>
15 LayerTestResult<T, NumDims> BatchMatMulTestImpl(
16     armnn::IWorkloadFactory& workloadFactory,
17     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
18     const armnn::ITensorHandleFactory& tensorHandleFactory,
19     armnn::BatchMatMulDescriptor descriptor,
20     const std::vector<T>& inputX,
21     const std::vector<T>& inputY,
22     const std::vector<T>& outputExpected,
23     const armnn::TensorInfo& inputXInfo,
24     const armnn::TensorInfo& inputYInfo,
25     const armnn::TensorInfo& outputInfo);
26 
27 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
28 LayerTestResult<T, 2> BatchMatMul2DSimpleTest(
29     armnn::IWorkloadFactory& workloadFactory,
30     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
31     const armnn::ITensorHandleFactory& tensorHandleFactory);
32 
33 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
34 LayerTestResult<T, 3> BatchMatMul3DSimpleTest(
35     armnn::IWorkloadFactory& workloadFactory,
36     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
37     const armnn::ITensorHandleFactory& tensorHandleFactory);
38 
39 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
40 LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
41     armnn::IWorkloadFactory& workloadFactory,
42     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
43     const armnn::ITensorHandleFactory& tensorHandleFactory);
44 
45 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
46 LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
47     armnn::IWorkloadFactory& workloadFactory,
48     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
49     const armnn::ITensorHandleFactory& tensorHandleFactory);
50 
51 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
52 LayerTestResult<T, 3> BatchMatMul3DBatchTest(
53     armnn::IWorkloadFactory& workloadFactory,
54     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
55     const armnn::ITensorHandleFactory& tensorHandleFactory);
56 
57 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
58 LayerTestResult<T, 3> BatchMatMul3DBroadcastTest(
59     armnn::IWorkloadFactory& workloadFactory,
60     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
61     const armnn::ITensorHandleFactory& tensorHandleFactory);
62 
63 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
64 LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest(
65     armnn::IWorkloadFactory& workloadFactory,
66     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
67     const armnn::ITensorHandleFactory& tensorHandleFactory);
68 
69 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
70 LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
71     armnn::IWorkloadFactory& workloadFactory,
72     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
73     const armnn::ITensorHandleFactory& tensorHandleFactory);
74 
75 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
76 LayerTestResult<T, 2> BatchMatMul2DTinyTest(
77     armnn::IWorkloadFactory& workloadFactory,
78     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
79     const armnn::ITensorHandleFactory& tensorHandleFactory);
80 
81 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
82 LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
83     armnn::IWorkloadFactory& workloadFactory,
84     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
85     const armnn::ITensorHandleFactory& tensorHandleFactory);
86 
87 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
88 LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest(
89     armnn::IWorkloadFactory& workloadFactory,
90     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
91     const armnn::ITensorHandleFactory& tensorHandleFactory);
92 
93 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
94 LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest(
95     armnn::IWorkloadFactory& workloadFactory,
96     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
97     const armnn::ITensorHandleFactory& tensorHandleFactory);
98 
99 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
100 LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
101     armnn::IWorkloadFactory& workloadFactory,
102     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
103     const armnn::ITensorHandleFactory& tensorHandleFactory);